From 8406a416a3ae49256e4ef612745ee42a48b06de4 Mon Sep 17 00:00:00 2001 From: Shimin Date: Tue, 9 Jul 2024 09:51:12 +0800 Subject: [PATCH 01/13] feat(wren-ui): Add References UI (#489) * feat(wren-ui): add clsx package * feat(wren-ui): rename promptThread to thread * feat(wren-ui): add references side float to thread * feat(wren-ui): show actual count for show all button --- wren-ui/package.json | 1 + .../pages/home/promptThread/AnswerResult.tsx | 148 ------------ .../pages/home/thread/AnswerResult.tsx | 163 ++++++++++++++ .../CollapseContent.tsx | 0 .../{promptThread => thread}/StepContent.tsx | 2 +- .../thread/feedback/ReferenceSideFloat.tsx | 213 ++++++++++++++++++ .../pages/home/thread/feedback/index.tsx | 103 +++++++++ .../pages/home/thread/feedback/utils.tsx | 27 +++ .../home/{promptThread => thread}/index.tsx | 14 +- wren-ui/src/pages/home/[id].tsx | 7 +- wren-ui/src/styles/components/tag.less | 16 ++ wren-ui/src/styles/utilities/text.less | 3 + wren-ui/src/utils/icons.ts | 6 +- wren-ui/yarn.lock | 5 + 14 files changed, 544 insertions(+), 164 deletions(-) delete mode 100644 wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx create mode 100644 wren-ui/src/components/pages/home/thread/AnswerResult.tsx rename wren-ui/src/components/pages/home/{promptThread => thread}/CollapseContent.tsx (100%) rename wren-ui/src/components/pages/home/{promptThread => thread}/StepContent.tsx (98%) create mode 100644 wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx create mode 100644 wren-ui/src/components/pages/home/thread/feedback/index.tsx create mode 100644 wren-ui/src/components/pages/home/thread/feedback/utils.tsx rename wren-ui/src/components/pages/home/{promptThread => thread}/index.tsx (90%) diff --git a/wren-ui/package.json b/wren-ui/package.json index 33f1dc17b..e20f77c13 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -60,6 +60,7 @@ "@typescript-eslint/parser": "6.18.0", "ace-builds": "^1.32.3", "antd": "4.20.4", + "clsx": "^2.1.1", "dayjs": "^1.11.11", "duckdb": "^0.10.1", "duckdb-async": "^0.10.0", diff --git a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx b/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx deleted file mode 100644 index 4d1910b9b..000000000 --- a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx +++ /dev/null @@ -1,148 +0,0 @@ -import { useState } from 'react'; -import Link from 'next/link'; -import { Col, Button, Row, Skeleton, Typography } from 'antd'; -import styled from 'styled-components'; -import { Path } from '@/utils/enum'; -import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; -import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; -import SaveOutlined from '@ant-design/icons/SaveOutlined'; -import FileDoneOutlined from '@ant-design/icons/FileDoneOutlined'; -import StepContent from '@/components/pages/home/promptThread/StepContent'; - -const { Title, Text } = Typography; - -const StyledAnswer = styled(Typography)` - position: relative; - border: 1px var(--gray-4) solid; - border-radius: 4px; - - .adm-answer-title { - font-weight: 500; - position: absolute; - top: -13px; - left: 8px; - background: white; - } -`; - -const StyledQuestion = styled(Row)` - padding: 4px 8px; - border-radius: 4px; - color: var(--gray-6); - background-color: var(--gray-3); - margin-bottom: 8px; - font-size: 14px; - - &:hover { - background-color: var(--gray-4) !important; - cursor: pointer; - } -`; - -interface Props { - loading: boolean; - question: string; - description: string; - answerResultSteps: Array<{ - summary: string; - sql: string; - }>; - fullSql: string; - threadResponseId: number; - onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; - isLastThreadResponse: boolean; - onTriggerScrollToBottom: () => void; - summary: string; - view?: { - id: number; - displayName: string; - }; -} - -export default function AnswerResult(props: Props) { - const { - loading, - question, - description, - answerResultSteps, - fullSql, - threadResponseId, - isLastThreadResponse, - onOpenSaveAsViewModal, - onTriggerScrollToBottom, - summary, - view, - } = props; - - const isViewSaved = !!view; - - const [ellipsis, setEllipsis] = useState(true); - - return ( - - setEllipsis(!ellipsis)}> - - - Question: - - - - {question} - - - - - {summary} - - - - - Summary - -
{description}
- {(answerResultSteps || []).map((step, index) => ( - - ))} -
- {isViewSaved ? ( -
- - Generated from saved view{' '} - - {view.displayName} - -
- ) : ( - - )} -
- ); -} diff --git a/wren-ui/src/components/pages/home/thread/AnswerResult.tsx b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx new file mode 100644 index 000000000..0a8aa90aa --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx @@ -0,0 +1,163 @@ +import { useState } from 'react'; +import Link from 'next/link'; +import { Col, Button, Row, Skeleton, Typography } from 'antd'; +import styled from 'styled-components'; +import { Path } from '@/utils/enum'; +import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; +import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; +import SaveOutlined from '@ant-design/icons/SaveOutlined'; +import FileDoneOutlined from '@ant-design/icons/FileDoneOutlined'; +import StepContent from '@/components/pages/home/thread/StepContent'; +import FeedbackLayout from '@/components/pages/home/thread/feedback'; + +const { Title, Text } = Typography; + +const Wrapper = styled.div` + width: 680px; +`; + +const StyledAnswer = styled(Typography)` + position: relative; + border: 1px var(--gray-4) solid; + border-radius: 4px; + + .adm-answer-title { + font-weight: 500; + position: absolute; + top: -13px; + left: 8px; + background: white; + } +`; + +const StyledQuestion = styled(Row)` + padding: 4px 8px; + border-radius: 4px; + color: var(--gray-6); + background-color: var(--gray-3); + margin-bottom: 8px; + font-size: 14px; + + &:hover { + background-color: var(--gray-4) !important; + cursor: pointer; + } +`; + +interface Props { + loading: boolean; + question: string; + description: string; + answerResultSteps: Array<{ + summary: string; + sql: string; + }>; + fullSql: string; + threadResponseId: number; + onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; + isLastThreadResponse: boolean; + onTriggerScrollToBottom: () => void; + summary: string; + view?: { + id: number; + displayName: string; + }; +} + +export default function AnswerResult(props: Props) { + const { + loading, + question, + description, + answerResultSteps, + fullSql, + threadResponseId, + isLastThreadResponse, + onOpenSaveAsViewModal, + onTriggerScrollToBottom, + summary, + view, + } = props; + + const isViewSaved = !!view; + + const [ellipsis, setEllipsis] = useState(true); + + return ( + + + setEllipsis(!ellipsis)}> + + + Question: + + + + {question} + + + + + {summary} + + + } + bodySlot={ + + + + + Summary + +
{description}
+ {(answerResultSteps || []).map((step, index) => ( + + ))} +
+ {isViewSaved ? ( +
+ + Generated from saved view{' '} + + {view.displayName} + +
+ ) : ( + + )} +
+ } + /> +
+ ); +} diff --git a/wren-ui/src/components/pages/home/promptThread/CollapseContent.tsx b/wren-ui/src/components/pages/home/thread/CollapseContent.tsx similarity index 100% rename from wren-ui/src/components/pages/home/promptThread/CollapseContent.tsx rename to wren-ui/src/components/pages/home/thread/CollapseContent.tsx diff --git a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx b/wren-ui/src/components/pages/home/thread/StepContent.tsx similarity index 98% rename from wren-ui/src/components/pages/home/promptThread/StepContent.tsx rename to wren-ui/src/components/pages/home/thread/StepContent.tsx index 905c86f75..2eb724f49 100644 --- a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx +++ b/wren-ui/src/components/pages/home/thread/StepContent.tsx @@ -4,7 +4,7 @@ import FunctionOutlined from '@ant-design/icons/FunctionOutlined'; import { BinocularsIcon } from '@/utils/icons'; import CollapseContent, { Props as CollapseContentProps, -} from '@/components/pages/home/promptThread/CollapseContent'; +} from '@/components/pages/home/thread/CollapseContent'; import useAnswerStepContent from '@/hooks/useAnswerStepContent'; import { nextTick } from '@/utils/time'; diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx new file mode 100644 index 000000000..1d077ed1c --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx @@ -0,0 +1,213 @@ +import { useMemo, useState } from 'react'; +import clsx from 'clsx'; +import styled from 'styled-components'; +import { Tag, Typography, Button, Input } from 'antd'; +import { EditOutlined } from '@ant-design/icons'; +import { QuoteIcon } from '@/utils/icons'; +import { makeIterable } from '@/utils/iteration'; +import { ReferenceTypes, getReferenceIcon } from './utils'; + +const StyledReferenceSideFloat = styled.div` + position: relative; + width: 325px; + + .referenceSideFloat-title { + position: absolute; + top: -14px; + padding: 0 4px; + } +`; + +interface Props { + references: any[]; + saveCorrectionPrompt?: (id: string, value: string) => void; +} + +const COLLAPSE_LIMIT = 3; + +const ReferenceSummaryTemplate = ({ title, type, referenceNum }) => { + return ( +
+ + {getReferenceIcon(type)} + {referenceNum} + + + {title} + +
+ ); +}; + +const GroupReferenceTemplate = ({ + name, + type, + data, + index, + saveCorrectionPrompt, +}) => { + if (!data.length) return null; + return ( +
0 })}> + + {getReferenceIcon(type)}{' '} + {name} + + +
+ ); +}; + +const ReferenceTemplate = ({ + id, + title, + type, + referenceNum, + correctionPrompt, + saveCorrectionPrompt, +}) => { + const [isEdit, setIsEdit] = useState(false); + const [value, setValue] = useState(correctionPrompt); + const isRevise = !!correctionPrompt; + + const openEdit = () => { + setIsEdit(!isEdit); + }; + + const handleEdit = () => { + saveCorrectionPrompt(id, value); + setIsEdit(false); + }; + + return ( +
+
+ + {getReferenceIcon(type)} + {referenceNum} + +
+
+ + {title} + + {isRevise ? ( + '(revised)' + ) : ( + + )} + + + {isEdit && ( +
+ + setValue(e.target.value)} + /> + + +
+ )} +
+
+ ); +}; + +const ReferenceSummaryIterator = makeIterable(ReferenceSummaryTemplate); +const GroupReferenceIterator = makeIterable(GroupReferenceTemplate); +const ReferenceIterator = makeIterable(ReferenceTemplate); + +const References = (props: Props) => { + const { references, saveCorrectionPrompt } = props; + + const fieldReferences = references.filter( + (ref) => ref.type === ReferenceTypes.FIELD, + ); + const queryFromReferences = references.filter( + (ref) => ref.type === ReferenceTypes.QUERY_FROM, + ); + const filterReferences = references.filter( + (ref) => ref.type === ReferenceTypes.FILTER, + ); + const sortingReferences = references.filter( + (ref) => ref.type === ReferenceTypes.SORTING, + ); + const groupByReferences = references.filter( + (ref) => ref.type === ReferenceTypes.GROUP_BY, + ); + + const resources = [ + { name: 'Fields', type: ReferenceTypes.FIELD, data: fieldReferences }, + { + name: 'Query from', + type: ReferenceTypes.QUERY_FROM, + data: queryFromReferences, + }, + { name: 'Filter', type: ReferenceTypes.FILTER, data: filterReferences }, + { name: 'Sorting', type: ReferenceTypes.SORTING, data: sortingReferences }, + { + name: 'Group by', + type: ReferenceTypes.GROUP_BY, + data: groupByReferences, + }, + ]; + + return ( + + ); +}; + +export default function ReferenceSideFloat(props: Props) { + const { references } = props; + const [collapse, setCollapse] = useState(false); + + const referencesSummary = useMemo( + () => references.slice(0, COLLAPSE_LIMIT), + [collapse, references], + ); + + const handleCollapse = () => { + setCollapse(!collapse); + }; + + if (references.length === 0) return null; + return ( + +
+ References +
+ {collapse ? ( + + ) : ( + <> + + + + )} +
+ ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/index.tsx b/wren-ui/src/components/pages/home/thread/feedback/index.tsx new file mode 100644 index 000000000..3ac8e074e --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/index.tsx @@ -0,0 +1,103 @@ +import { createContext, useContext, useMemo, useState } from 'react'; +import ReferenceSideFloat from './ReferenceSideFloat'; +import { ReferenceTypes } from './utils'; + +type ContextProps = { + references: any[]; +} | null; + +export const FeedbackContext = createContext({ + references: [], +}); + +export const useFeedbackContext = () => { + return useContext(FeedbackContext); +}; + +interface Props { + headerSlot: React.ReactNode; + bodySlot: React.ReactNode; +} + +const data = [ + { + id: '1', + type: ReferenceTypes.FIELD, + title: + "Selects the 'City' column from the 'customer_data' dataset to display the city name.", + referenceNum: 1, + }, + { + id: '2', + type: ReferenceTypes.FIELD, + title: 'Reference 2', + referenceNum: 2, + }, + { + id: '3', + type: ReferenceTypes.QUERY_FROM, + title: 'Reference 3', + referenceNum: 3, + }, + { + id: '4', + type: ReferenceTypes.QUERY_FROM, + title: 'Reference 4', + referenceNum: 4, + }, + { + id: '5', + type: ReferenceTypes.FILTER, + title: 'Reference 4', + referenceNum: 4, + }, + { + id: '6', + type: ReferenceTypes.SORTING, + title: 'Reference 4', + referenceNum: 4, + }, + { + id: '7', + type: ReferenceTypes.GROUP_BY, + title: 'Reference 4', + referenceNum: 4, + }, +]; + +export default function Feedback(props: Props) { + const { headerSlot, bodySlot } = props; + + const [correctionPrompts, setCorrectionPrompts] = useState({}); + + const saveCorrectionPrompt = (id: string, value: string) => { + setCorrectionPrompts({ ...correctionPrompts, [id]: value }); + }; + + const references = useMemo(() => { + return data.map((item) => ({ + ...item, + correctionPrompt: correctionPrompts[item.id], + })); + }, [data, correctionPrompts]); + + const contextValue = { + references, + saveCorrectionPrompt, + }; + + return ( + +
{headerSlot}
+
+ {bodySlot} +
+ +
+
+
+ ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/utils.tsx b/wren-ui/src/components/pages/home/thread/feedback/utils.tsx new file mode 100644 index 000000000..9d58373ed --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/utils.tsx @@ -0,0 +1,27 @@ +import { + FilterOutlined, + SortAscendingOutlined, + GroupOutlined, +} from '@ant-design/icons'; +import { ColumnsIcon, ModelIcon } from '@/utils/icons'; + +// TODO: Replace after provided by the backend +export enum ReferenceTypes { + FIELD = 'FIELD', + QUERY_FROM = 'QUERY_FROM', + FILTER = 'FILTER', + SORTING = 'SORTING', + GROUP_BY = 'GROUP_BY', +} + +export const getReferenceIcon = (type) => { + return ( + { + [ReferenceTypes.FIELD]: , + [ReferenceTypes.QUERY_FROM]: , + [ReferenceTypes.FILTER]: , + [ReferenceTypes.SORTING]: , + [ReferenceTypes.GROUP_BY]: , + }[type] || null + ); +}; diff --git a/wren-ui/src/components/pages/home/promptThread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx similarity index 90% rename from wren-ui/src/components/pages/home/promptThread/index.tsx rename to wren-ui/src/components/pages/home/thread/index.tsx index d066ab8ed..404de214b 100644 --- a/wren-ui/src/components/pages/home/promptThread/index.tsx +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -13,11 +13,7 @@ interface Props { onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; } -const StyledPromptThread = styled.div` - width: 768px; - margin-left: auto; - margin-right: auto; - +const StyledThread = styled.div` h4.ant-typography { margin-top: 10px; } @@ -48,7 +44,7 @@ const AnswerResultTemplate = ({ const isLastThreadResponse = id === lastResponseId; return ( -
+
{index > 0 && } {error ? ( (null); @@ -105,12 +101,12 @@ export default function PromptThread(props: Props) { }, [divRef, data]); return ( - + - + ); } diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index 011fa7e6d..4f7e4f2cb 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -13,7 +13,7 @@ import { } from '@/apollo/client/graphql/home.generated'; import useAskPrompt, { getIsFinished } from '@/hooks/useAskPrompt'; import useModalAction from '@/hooks/useModalAction'; -import PromptThread from '@/components/pages/home/promptThread'; +import Thread from '@/components/pages/home/thread'; import SaveAsViewModal from '@/components/modals/SaveAsViewModal'; import { useCreateViewMutation } from '@/apollo/client/graphql/view.generated'; @@ -109,10 +109,7 @@ export default function HomeThread() { return ( - +
Date: Thu, 11 Jul 2024 17:51:47 +0800 Subject: [PATCH 02/13] Chore: add migration files and update adaptors for feedback loop feature (#506) * chore: Add migration for creating thread_response_explain table * chore: Add analysisSql method to WrenEngineAdaptor * adding explain api to ai service adapter: * fix migration errer * chore(wren-ui): Add regenerations API to wrenAIAdaptor --------- Co-authored-by: andreashimin --- ...33_create_thread_response_explain_table.js | 32 ++++ ...0711082655_update_thread_response_table.js | 22 +++ .../apollo/server/adaptors/wrenAIAdaptor.ts | 152 ++++++++++++++++-- .../server/adaptors/wrenEngineAdaptor.ts | 22 ++- 4 files changed, 204 insertions(+), 24 deletions(-) create mode 100644 wren-ui/migrations/20240711021133_create_thread_response_explain_table.js create mode 100644 wren-ui/migrations/20240711082655_update_thread_response_table.js diff --git a/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js b/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js new file mode 100644 index 000000000..85b47b819 --- /dev/null +++ b/wren-ui/migrations/20240711021133_create_thread_response_explain_table.js @@ -0,0 +1,32 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.createTable('thread_response_explain', (table) => { + table.increments('id').comment('ID'); + table + .integer('thread_response_id') + .comment('Reference to thread_response.id'); + table + .foreign('thread_response_id') + .references('thread_response.id') + .onDelete('CASCADE'); + + table.string('query_id').notNullable(); + table.string('status').notNullable(); + table.jsonb('detail').notNullable(); + table.jsonb('error').notNullable(); + + // timestamps + table.timestamps(true, true); + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.dropTable('thread_response_explain'); +}; diff --git a/wren-ui/migrations/20240711082655_update_thread_response_table.js b/wren-ui/migrations/20240711082655_update_thread_response_table.js new file mode 100644 index 000000000..8cf1fa33b --- /dev/null +++ b/wren-ui/migrations/20240711082655_update_thread_response_table.js @@ -0,0 +1,22 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table + .jsonb('corrections') + .nullable() + .comment('the corrections of the previous thread response'); // [{type, id, correct}, ...] + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table.dropColumn('corrections'); + }); +}; diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index 507b11d92..1e9a3c47a 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -55,6 +55,15 @@ export interface AsyncQueryResponse { queryId: string; } +export type ExplainResult = AIServiceResponse; + +export enum ExplainPipelineStatus { + UNDERSTANDING = 'UNDERSTANDING', + GENERATING = 'GENERATING', + FINISHED = 'FINISHED', + FAILED = 'FAILED', +} + export enum AskResultStatus { UNDERSTANDING = 'UNDERSTANDING', SEARCHING = 'SEARCHING', @@ -71,7 +80,21 @@ export enum AskCandidateType { LLM = 'LLM', } -export interface AskResponse { +export enum ExplainType { + FILTER = 'filter', + SELECT_ITEMS = 'selectItems', + RELATION = 'relation', + GROUP_BY_KEYS = 'groupByKeys', + SORTINGS = 'sortings', +} + +// UI currently only support nl_expression +export enum ExpressionType { + SQL_EXPRESSION = 'sql_expression', + NL_EXPRESSION = 'nl_expression', +} + +export interface AIServiceResponse { status: S; response: R | null; error: WrenAIError | null; @@ -83,7 +106,7 @@ export interface AskDetailInput { summary: string; } -export type AskDetailResult = AskResponse< +export type AskDetailResult = AIServiceResponse< { description: string; steps: AskStep[]; @@ -91,7 +114,7 @@ export type AskDetailResult = AskResponse< AskResultStatus >; -export type AskResult = AskResponse< +export type AskResult = AIServiceResponse< Array<{ type: AskCandidateType; sql: string; @@ -101,7 +124,26 @@ export type AskResult = AskResponse< AskResultStatus >; -const getAISerciceError = (error: any) => { +export interface CorrectionObject { + type: T; + value: string; +} + +export interface AskCorrection { + before: CorrectionObject; + after: CorrectionObject; +} + +export interface AskStepWithCorrections extends AskStep { + corrections: AskCorrection[]; +} + +export interface RegenerateAskDetailInput { + description: string; + steps: AskStepWithCorrections[]; +} + +const getAIServiceError = (error: any) => { const { data } = error.response || {}; return data?.detail ? `${error.message}, detail: ${data.detail}` @@ -129,6 +171,12 @@ export interface IWrenAIAdaptor { */ generateAskDetail(input: AskDetailInput): Promise; getAskDetailResult(queryId: string): Promise; + explain(question: string, analysisResults: any): Promise; + getExplainResult(queryId: string): Promise; + regenerateAskDetail( + input: RegenerateAskDetailInput, + ): Promise; + getRegeneratedAskDetailResult(queryId: string): Promise; } export class WrenAIAdaptor implements IWrenAIAdaptor { @@ -148,11 +196,11 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { const res = await axios.post(`${this.wrenAIBaseEndpoint}/v1/asks`, { query: input.query, id: input.deployId, - history: this.transfromHistoryInput(input.history), + history: this.transformHistoryInput(input.history), }); return { queryId: res.data.query_id }; } catch (err: any) { - logger.debug(`Got error when asking wren AI: ${getAISerciceError(err)}`); + logger.debug(`Got error when asking wren AI: ${getAIServiceError(err)}`); throw err; } } @@ -164,7 +212,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { status: 'stopped', }); } catch (err: any) { - logger.debug(`Got error when canceling ask: ${getAISerciceError(err)}`); + logger.debug(`Got error when canceling ask: ${getAIServiceError(err)}`); throw err; } } @@ -178,7 +226,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return this.transformAskResult(res.data); } catch (err: any) { logger.debug( - `Got error when getting ask result: ${getAISerciceError(err)}`, + `Got error when getting ask result: ${getAIServiceError(err)}`, ); // throw err; throw Errors.create(Errors.GeneralErrorCodes.INTERNAL_SERVER_ERROR, { @@ -187,6 +235,44 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } } + public async explain( + question: string, + analysisResults: any, + ): Promise { + try { + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/sql-explanations`, + { + question, + steps_with_analysis_results: analysisResults, + }, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug(`Got error when explaining: ${getAIServiceError(err)}`); + throw err; + } + } + + public async getExplainResult(queryId: string): Promise { + // make GET request /v1/sql-explanations/:query_id/result to get the result + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/sql-explanations/${queryId}/result`, + ); + return { + status: res.data.status as ExplainPipelineStatus, + response: res.data.response, + error: this.transformStatusAndError(res.data).error, + }; + } catch (err: any) { + logger.debug( + `Got error when getting explain result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + /** * After you choose a candidate, you can request AI service to generate the detail. */ @@ -202,7 +288,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return { queryId: res.data.query_id }; } catch (err: any) { logger.debug( - `Got error when generating ask detail: ${getAISerciceError(err)}`, + `Got error when generating ask detail: ${getAIServiceError(err)}`, ); throw err; } @@ -217,7 +303,37 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { return this.transformAskDetailResult(res.data); } catch (err: any) { logger.debug( - `Got error when getting ask detail result: ${getAISerciceError(err)}`, + `Got error when getting ask detail result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async regenerateAskDetail(input: RegenerateAskDetailInput) { + try { + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/sql-regenerations`, + input, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug( + `Got error when regenerating ask detail: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async getRegeneratedAskDetailResult(queryId: string) { + // make GET request /v1/sql-regenerations/:query_id/result to get the result + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/sql-regenerations/${queryId}/result`, + ); + return this.transformAskDetailResult(res.data); + } catch (err: any) { + logger.debug( + `Got error when getting regenerated ask detail result: ${getAIServiceError(err)}`, ); throw err; } @@ -309,7 +425,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { })); return { - status, + status: status as AskResultStatus, error, response: candidates, }; @@ -326,7 +442,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { })); return { - status, + status: status as AskResultStatus, error, response: { description: body?.response?.description, @@ -336,7 +452,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } private transformStatusAndError(body: any): { - status: AskResultStatus; + status: AskResultStatus | ExplainPipelineStatus; error?: { code: Errors.GeneralErrorCodes; message: string; @@ -344,9 +460,11 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } | null; } { // transform status to enum - const status = AskResultStatus[ - body?.status?.toUpperCase() - ] as AskResultStatus; + const status = + (AskResultStatus[body?.status?.toUpperCase()] as AskResultStatus) || + (ExplainPipelineStatus[ + body.status + ]?.toUpperCase() as ExplainPipelineStatus); if (!status) { throw new Error(`Unknown ask status: ${body?.status}`); @@ -380,7 +498,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { }; } - private transfromHistoryInput(history: AskHistory) { + private transformHistoryInput(history: AskHistory) { if (!history) { return null; } diff --git a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts index 8fb0a47e8..d8c85eec4 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts @@ -94,6 +94,9 @@ export interface IWrenEngineAdaptor { sql: string, options: WrenEngineDryRunOption, ): Promise; + + // analysis + analysisSql(sql: string, mdl: Manifest): Promise; } export class WrenEngineAdaptor implements IWrenEngineAdaptor { @@ -105,6 +108,7 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { private dryPlanUrlPath = '/v1/mdl/dry-plan'; private dryRunUrlPath = '/v1/mdl/dry-run'; private validateUrlPath = '/v1/mdl/validate'; + private analysisUrlPath = '/v1/analysis/sql'; constructor({ wrenEngineEndpoint }: { wrenEngineEndpoint: string }) { this.wrenEngineBaseEndpoint = wrenEngineEndpoint; @@ -315,16 +319,20 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { } } - private async getDeployStatus(): Promise { + public async analysisSql(sql: string, mdl: Manifest) { try { - const res = await axios.get( - `${this.wrenEngineBaseEndpoint}/v1/mdl/status`, + const url = new URL(this.analysisUrlPath, this.wrenEngineBaseEndpoint); + const headers = { + 'Content-Type': 'application/json', + }; + const res = await axios.post( + url.href, + { sql, manifest: mdl }, + { headers }, ); - return res.data as WrenEngineDeployStatusResponse; + return res.data; } catch (err: any) { - logger.debug( - `WrenEngine: Got error when getting deploy status: ${err.message}`, - ); + logger.debug(`Got error when analyzing sql: ${err.message}`); throw err; } } From a0748db7e278ba60fd4353f6b56a445cea0bbd53 Mon Sep 17 00:00:00 2001 From: Shimin Date: Fri, 12 Jul 2024 10:35:17 +0800 Subject: [PATCH 03/13] feat(wren-ui): Add Pending Feedbacks UI (#498) * feat(wren-ui): adjust ReferenceSideFloat detailed styles * feat(wren-ui): add AdjustmentSideFloat & ReviewDrawer * feat(wren-ui): change wording to feedback * feat(wren-ui): rename component & filename --- .../thread/feedback/FeedbackSideFloat.tsx | 63 +++++++ .../thread/feedback/ReferenceSideFloat.tsx | 39 ++-- .../home/thread/feedback/ReviewDrawer.tsx | 167 ++++++++++++++++++ .../pages/home/thread/feedback/index.tsx | 35 +++- .../components/pages/home/thread/index.tsx | 10 +- wren-ui/src/styles/utilities/spacing.less | 5 + wren-ui/src/styles/utilities/text.less | 4 + 7 files changed, 304 insertions(+), 19 deletions(-) create mode 100644 wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx create mode 100644 wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx diff --git a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx new file mode 100644 index 000000000..ba52bac84 --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx @@ -0,0 +1,63 @@ +import clsx from 'clsx'; +import { useMemo } from 'react'; +import styled from 'styled-components'; +import { Button, Popconfirm } from 'antd'; +import { FileTextOutlined } from '@ant-design/icons'; + +const StyledFeedbackSideFloat = styled.div` + position: relative; + width: 325px; + + .feedbackSideFloat-title { + position: absolute; + top: -14px; + padding: 0 4px; + } +`; + +interface Props { + className?: string; + references: any[]; + onOpenReviewDrawer: () => void; + onResetAllChanges: () => void; +} + +export default function FeedbackSideFloat(props: Props) { + const { className, references, onOpenReviewDrawer, onResetAllChanges } = + props; + + const changedReferences = useMemo(() => { + return (references || []).filter((item) => !!item.correctionPrompt); + }, [references]); + + if (changedReferences.length === 0) return null; + return ( + +
+ Pending feedbacks +
+
+ + + + +
+
+ ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx index 1d077ed1c..851f2ad6b 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx @@ -1,5 +1,5 @@ -import { useMemo, useState } from 'react'; import clsx from 'clsx'; +import { useMemo, useState } from 'react'; import styled from 'styled-components'; import { Tag, Typography, Button, Input } from 'antd'; import { EditOutlined } from '@ant-design/icons'; @@ -9,7 +9,7 @@ import { ReferenceTypes, getReferenceIcon } from './utils'; const StyledReferenceSideFloat = styled.div` position: relative; - width: 325px; + width: 330px; .referenceSideFloat-title { position: absolute; @@ -20,15 +20,21 @@ const StyledReferenceSideFloat = styled.div` interface Props { references: any[]; - saveCorrectionPrompt?: (id: string, value: string) => void; + onSaveCorrectionPrompt?: (id: string, value: string) => void; } const COLLAPSE_LIMIT = 3; -const ReferenceSummaryTemplate = ({ title, type, referenceNum }) => { +const ReferenceSummaryTemplate = ({ + title, + type, + referenceNum, + correctionPrompt, +}) => { + const isRevise = !!correctionPrompt; return (
- + {getReferenceIcon(type)} {referenceNum} @@ -80,6 +86,7 @@ const ReferenceTemplate = ({ const handleEdit = () => { saveCorrectionPrompt(id, value); setIsEdit(false); + setValue(''); }; return ( @@ -95,7 +102,7 @@ const ReferenceTemplate = ({ {title} {isRevise ? ( - '(revised)' + '(feedback suggested)' ) : ( )} @@ -107,9 +114,10 @@ const ReferenceTemplate = ({ setValue(e.target.value)} + onPressEnter={handleEdit} /> + + + } + > + + + ); +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/index.tsx b/wren-ui/src/components/pages/home/thread/feedback/index.tsx index 3ac8e074e..e7671c28d 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/index.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/index.tsx @@ -1,5 +1,8 @@ import { createContext, useContext, useMemo, useState } from 'react'; -import ReferenceSideFloat from './ReferenceSideFloat'; +import ReferenceSideFloat from '@/components/pages/home/thread/feedback/ReferenceSideFloat'; +import FeedbackSideFloat from '@/components/pages/home/thread/feedback/FeedbackSideFloat'; +import ReviewDrawer from '@/components/pages/home/thread/feedback/ReviewDrawer'; +import useDrawerAction from '@/hooks/useDrawerAction'; import { ReferenceTypes } from './utils'; type ContextProps = { @@ -69,11 +72,20 @@ export default function Feedback(props: Props) { const { headerSlot, bodySlot } = props; const [correctionPrompts, setCorrectionPrompts] = useState({}); + const reviewDrawer = useDrawerAction(); const saveCorrectionPrompt = (id: string, value: string) => { setCorrectionPrompts({ ...correctionPrompts, [id]: value }); }; + const removeCorrectionPrompt = (id: string) => { + setCorrectionPrompts({ ...correctionPrompts, [id]: undefined }); + }; + + const resetAllCorrectionPrompts = () => { + setCorrectionPrompts({}); + }; + const references = useMemo(() => { return data.map((item) => ({ ...item, @@ -88,16 +100,33 @@ export default function Feedback(props: Props) { return ( -
{headerSlot}
+
+ {headerSlot} +
+ +
+
{bodySlot}
+
); } diff --git a/wren-ui/src/components/pages/home/thread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx index 404de214b..0a3b5a453 100644 --- a/wren-ui/src/components/pages/home/thread/index.tsx +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -28,6 +28,10 @@ const StyledThread = styled.div` } `; +const StyledContainer = styled.div` + max-width: 1030px; +`; + const AnswerResultTemplate = ({ index, id, @@ -44,7 +48,7 @@ const AnswerResultTemplate = ({ const isLastThreadResponse = id === lastResponseId; return ( -
+ {index > 0 && } {error ? ( )} -
+ ); }; @@ -101,7 +105,7 @@ export default function Thread(props: Props) { }, [divRef, data]); return ( - + Date: Fri, 12 Jul 2024 17:19:39 +0800 Subject: [PATCH 04/13] Chore: move background tracker for thread responses (#510) * Chore: move background tracker for thread responses * use abstract class * rm comment --- .../apollo/server/backgroundTrackers/index.ts | 12 ++ .../threadResponseBackgroundTracker.ts | 112 ++++++++++++++++++ .../apollo/server/services/askingService.ts | 111 +---------------- 3 files changed, 130 insertions(+), 105 deletions(-) create mode 100644 wren-ui/src/apollo/server/backgroundTrackers/index.ts create mode 100644 wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts diff --git a/wren-ui/src/apollo/server/backgroundTrackers/index.ts b/wren-ui/src/apollo/server/backgroundTrackers/index.ts new file mode 100644 index 000000000..294fb59e4 --- /dev/null +++ b/wren-ui/src/apollo/server/backgroundTrackers/index.ts @@ -0,0 +1,12 @@ +import { Telemetry } from '../telemetry/telemetry'; + +export abstract class BackgroundTracker { + protected tasks: Record = {}; + protected intervalTime: number = 1000; + protected runningJobs: Set = new Set(); + protected telemetry: Telemetry; + + public abstract start(): void; + public abstract addTask(task: R): void; + public abstract getTasks(): Record; +} diff --git a/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts new file mode 100644 index 000000000..1738f7d80 --- /dev/null +++ b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts @@ -0,0 +1,112 @@ +import { AskResultStatus, IWrenAIAdaptor } from '../adaptors/wrenAIAdaptor'; +import { + IThreadResponseRepository, + ThreadResponse, +} from '../repositories/threadResponseRepository'; +import { Telemetry } from '../telemetry/telemetry'; +import { BackgroundTracker } from './index'; +import { getLogger } from '@server/utils/logger'; + +const logger = getLogger('ThreadResponseBackgroundTracker'); +logger.level = 'debug'; + +export class ThreadResponseBackgroundTracker extends BackgroundTracker { + // tasks is a kv pair of task id and thread response + private wrenAIAdaptor: IWrenAIAdaptor; + private threadResponseRepository: IThreadResponseRepository; + + constructor({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + }: { + telemetry: Telemetry; + wrenAIAdaptor: IWrenAIAdaptor; + threadResponseRepository: IThreadResponseRepository; + }) { + super(); + this.telemetry = telemetry; + this.wrenAIAdaptor = wrenAIAdaptor; + this.threadResponseRepository = threadResponseRepository; + this.start(); + } + + public start() { + logger.info('Background tracker started'); + setInterval(() => { + const jobs = Object.values(this.tasks).map( + (threadResponse) => async () => { + // check if same job is running + if (this.runningJobs.has(threadResponse.id)) { + return; + } + + // mark the job as running + this.runningJobs.add(threadResponse.id); + + // get the latest result from AI service + const result = await this.wrenAIAdaptor.getAskDetailResult( + threadResponse.queryId, + ); + + // check if status change + if (threadResponse.status === result.status) { + // mark the job as finished + logger.debug( + `Job ${threadResponse.id} status not changed, finished`, + ); + this.runningJobs.delete(threadResponse.id); + return; + } + + // update database + logger.debug(`Job ${threadResponse.id} status changed, updating`); + await this.threadResponseRepository.updateOne(threadResponse.id, { + status: result.status, + detail: result.response, + error: result.error, + }); + + // remove the task from tracker if it is finalized + if (this.isFinalized(result.status)) { + this.telemetry.send_event('question_answered', { + question: threadResponse.question, + result, + }); + logger.debug(`Job ${threadResponse.id} is finalized, removing`); + delete this.tasks[threadResponse.id]; + } + + // mark the job as finished + this.runningJobs.delete(threadResponse.id); + }, + ); + + // run the jobs + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // show reason of rejection + results.forEach((result, index) => { + if (result.status === 'rejected') { + logger.error(`Job ${index} failed: ${result.reason}`); + } + }); + }); + }, this.intervalTime); + } + + public addTask(threadResponse: ThreadResponse) { + this.tasks[threadResponse.id] = threadResponse; + } + + public getTasks() { + return this.tasks; + } + + private isFinalized = (status: AskResultStatus) => { + return ( + status === AskResultStatus.FAILED || + status === AskResultStatus.FINISHED || + status === AskResultStatus.STOPPED + ); + }; +} diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 79d944bdc..4ec533ea5 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -18,6 +18,7 @@ import { format } from 'sql-formatter'; import { Telemetry } from '../telemetry/telemetry'; import { IViewRepository, View } from '../repositories'; import { IQueryService, PreviewDataResponse } from './queryService'; +import { ThreadResponseBackgroundTracker } from '../backgroundTrackers/threadResponseBackgroundTracker'; const logger = getLogger('AskingService'); logger.level = 'debug'; @@ -132,106 +133,6 @@ export const constructCteSql = ( return sql; }; -/** - * Background tracker to track the status of the asking detail task - */ -class BackgroundTracker { - // tasks is a kv pair of task id and thread response - private tasks: Record = {}; - private intervalTime: number; - private wrenAIAdaptor: IWrenAIAdaptor; - private threadResponseRepository: IThreadResponseRepository; - private runningJobs = new Set(); - private telemetry: Telemetry; - - constructor({ - telemetry, - wrenAIAdaptor, - threadResponseRepository, - }: { - telemetry: Telemetry; - wrenAIAdaptor: IWrenAIAdaptor; - threadResponseRepository: IThreadResponseRepository; - }) { - this.telemetry = telemetry; - this.wrenAIAdaptor = wrenAIAdaptor; - this.threadResponseRepository = threadResponseRepository; - this.intervalTime = 1000; - this.start(); - } - - public start() { - logger.info('Background tracker started'); - setInterval(() => { - const jobs = Object.values(this.tasks).map( - (threadResponse) => async () => { - // check if same job is running - if (this.runningJobs.has(threadResponse.id)) { - return; - } - - // mark the job as running - this.runningJobs.add(threadResponse.id); - - // get the latest result from AI service - const result = await this.wrenAIAdaptor.getAskDetailResult( - threadResponse.queryId, - ); - - // check if status change - if (threadResponse.status === result.status) { - // mark the job as finished - logger.debug( - `Job ${threadResponse.id} status not changed, finished`, - ); - this.runningJobs.delete(threadResponse.id); - return; - } - - // update database - logger.debug(`Job ${threadResponse.id} status changed, updating`); - await this.threadResponseRepository.updateOne(threadResponse.id, { - status: result.status, - detail: result.response, - error: result.error, - }); - - // remove the task from tracker if it is finalized - if (isFinalized(result.status)) { - this.telemetry.send_event('question_answered', { - question: threadResponse.question, - result, - }); - logger.debug(`Job ${threadResponse.id} is finalized, removing`); - delete this.tasks[threadResponse.id]; - } - - // mark the job as finished - this.runningJobs.delete(threadResponse.id); - }, - ); - - // run the jobs - Promise.allSettled(jobs.map((job) => job())).then((results) => { - // show reason of rejection - results.forEach((result, index) => { - if (result.status === 'rejected') { - logger.error(`Job ${index} failed: ${result.reason}`); - } - }); - }); - }, this.intervalTime); - } - - public addTask(threadResponse: ThreadResponse) { - this.tasks[threadResponse.id] = threadResponse; - } - - public getTasks() { - return this.tasks; - } -} - export class AskingService implements IAskingService { private wrenAIAdaptor: IWrenAIAdaptor; private deployService: IDeployService; @@ -239,7 +140,7 @@ export class AskingService implements IAskingService { private viewRepository: IViewRepository; private threadRepository: IThreadRepository; private threadResponseRepository: IThreadResponseRepository; - private backgroundTracker: BackgroundTracker; + private backgroundTracker: ThreadResponseBackgroundTracker; private queryService: IQueryService; private telemetry: Telemetry; @@ -270,7 +171,7 @@ export class AskingService implements IAskingService { this.threadResponseRepository = threadResponseRepository; this.telemetry = telemetry; this.queryService = queryService; - this.backgroundTracker = new BackgroundTracker({ + this.backgroundTracker = new ThreadResponseBackgroundTracker({ telemetry, wrenAIAdaptor, threadResponseRepository, @@ -281,14 +182,14 @@ export class AskingService implements IAskingService { // list thread responses from database // filter status not finalized and put them into background tracker const threadResponses = await this.threadResponseRepository.findAll(); - const unfininshedThreadResponses = threadResponses.filter( + const unfinishedThreadResponses = threadResponses.filter( (threadResponse) => !isFinalized(threadResponse.status as AskResultStatus), ); logger.info( - `Initialization: adding unfininshed thread responses (total: ${unfininshedThreadResponses.length}) to background tracker`, + `Initialization: adding unfinished thread responses (total: ${unfinishedThreadResponses.length}) to background tracker`, ); - for (const threadResponse of unfininshedThreadResponses) { + for (const threadResponse of unfinishedThreadResponses) { this.backgroundTracker.addTask(threadResponse); } } From 76d573476fcacfb0253fc7dc29841601e1810a41 Mon Sep 17 00:00:00 2001 From: Shimin Date: Fri, 19 Jul 2024 15:19:01 +0800 Subject: [PATCH 05/13] feat(wren-ui): Create Regenerated Response API & Add Regenerated Information in Thread Detail UI (#527) * feat(wren-ui): add regeneration API * feat(wren-ui): make strategy in thread response background tracker * feat(wren-ui): rename enum & add comment in AI adapter * feat(wren-ui): implement RegeneratedThreadResponse graphQL API * feat(wren-ui): modify getResponsesWithThread in repository & provide corrections in getThread resolver * feat(wren-ui): generate createRegeneratedThreadResponse graphql API * feat(wren-ui): implement submit review drawer flow * feat(wren-ui): refine type issues in UI * feat(wren-ui): add underline & select-none utilities class * feat(wren-ui): add regenerated answer information on UI * feat(wren-ui): refine review drawer submit flow * feat(wren-ui): refine error handler when submit review * feat(wren-ui): rename mutation to CreateCorrectedThreadResponse * fix(wren-ui): remove thread existence check --- .../src/apollo/client/graphql/__types__.ts | 36 +++++ .../apollo/client/graphql/home.generated.ts | 69 +++++++-- wren-ui/src/apollo/client/graphql/home.ts | 24 ++- .../apollo/server/adaptors/wrenAIAdaptor.ts | 25 ++-- .../threadResponseBackgroundTracker.ts | 51 ++++++- .../repositories/threadResponseRepository.ts | 60 +++++--- wren-ui/src/apollo/server/resolvers.ts | 1 + .../apollo/server/resolvers/askingResolver.ts | 31 ++++ wren-ui/src/apollo/server/schema.ts | 32 ++++ .../apollo/server/services/askingService.ts | 89 ++++++++++- .../pages/home/thread/AnswerResult.tsx | 140 ++++++++++++------ .../thread/feedback/FeedbackSideFloat.tsx | 15 +- .../thread/feedback/ReferenceSideFloat.tsx | 37 ++--- .../home/thread/feedback/ReviewDrawer.tsx | 42 ++++-- .../pages/home/thread/feedback/index.tsx | 67 ++++++--- .../pages/home/thread/feedback/utils.tsx | 26 ++-- .../components/pages/home/thread/index.tsx | 43 +++--- wren-ui/src/pages/home/[id].tsx | 58 ++++++-- wren-ui/src/styles/utilities/display.less | 4 + wren-ui/src/styles/utilities/text.less | 8 + wren-ui/src/utils/errorHandler.tsx | 13 ++ 21 files changed, 669 insertions(+), 202 deletions(-) diff --git a/wren-ui/src/apollo/client/graphql/__types__.ts b/wren-ui/src/apollo/client/graphql/__types__.ts index 9310c0d43..afdb92ca9 100644 --- a/wren-ui/src/apollo/client/graphql/__types__.ts +++ b/wren-ui/src/apollo/client/graphql/__types__.ts @@ -71,6 +71,13 @@ export type ConnectionInfo = { username?: Maybe; }; +export type CorrectionDetail = { + __typename?: 'CorrectionDetail'; + correction: Scalars['String']; + id: Scalars['Int']; + type: ReferenceType; +}; + export type CreateCalculatedFieldInput = { expression: ExpressionName; lineage: Array; @@ -78,6 +85,11 @@ export type CreateCalculatedFieldInput = { name: Scalars['String']; }; +export type CreateCorrectedThreadResponseInput = { + corrections: Array; + responseId: Scalars['Int']; +}; + export type CreateModelInput = { fields: Array; primaryKey?: InputMaybe; @@ -104,6 +116,14 @@ export type CreateThreadInput = { viewId?: InputMaybe; }; +export type CreateThreadResponseCorrectionInput = { + correction: Scalars['String']; + id: Scalars['Int']; + reference: Scalars['String']; + stepIndex: Scalars['Int']; + type: ReferenceType; +}; + export type CreateThreadResponseInput = { question?: InputMaybe; sql?: InputMaybe; @@ -399,6 +419,7 @@ export type Mutation = { cancelAskingTask: Scalars['Boolean']; createAskingTask: Task; createCalculatedField: Scalars['JSON']; + createCorrectedThreadResponse: ThreadResponse; createModel: Scalars['JSON']; createRelation: Scalars['JSON']; createThread: Thread; @@ -448,6 +469,12 @@ export type MutationCreateCalculatedFieldArgs = { }; +export type MutationCreateCorrectedThreadResponseArgs = { + data: CreateCorrectedThreadResponseInput; + threadId: Scalars['Int']; +}; + + export type MutationCreateModelArgs = { data: CreateModelInput; }; @@ -704,6 +731,14 @@ export type RecommendRelations = { relations: Array>; }; +export enum ReferenceType { + FIELD = 'FIELD', + FILTER = 'FILTER', + GROUP_BY = 'GROUP_BY', + QUERY_FROM = 'QUERY_FROM', + SORTING = 'SORTING' +} + export type Relation = { __typename?: 'Relation'; fromColumnId: Scalars['Int']; @@ -827,6 +862,7 @@ export type Thread = { export type ThreadResponse = { __typename?: 'ThreadResponse'; + corrections?: Maybe>; detail?: Maybe; error?: Maybe; id: Scalars['Int']; diff --git a/wren-ui/src/apollo/client/graphql/home.generated.ts b/wren-ui/src/apollo/client/graphql/home.generated.ts index 483a252a3..e8a62b88b 100644 --- a/wren-ui/src/apollo/client/graphql/home.generated.ts +++ b/wren-ui/src/apollo/client/graphql/home.generated.ts @@ -5,7 +5,7 @@ import * as Apollo from '@apollo/client'; const defaultOptions = {} as const; export type CommonErrorFragment = { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null }; -export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null }; +export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, correction: string }> | null }; export type SuggestedQuestionsQueryVariables = Types.Exact<{ [key: string]: never; }>; @@ -29,14 +29,14 @@ export type ThreadQueryVariables = Types.Exact<{ }>; -export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, sql: string, summary: string, responses: Array<{ __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null }> } }; +export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, sql: string, summary: string, responses: Array<{ __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, correction: string }> | null }> } }; export type ThreadResponseQueryVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null } }; +export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, correction: string }> | null } }; export type CreateAskingTaskMutationVariables = Types.Exact<{ data: Types.AskingTaskInput; @@ -65,7 +65,7 @@ export type CreateThreadResponseMutationVariables = Types.Exact<{ }>; -export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null } }; +export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, correction: string }> | null } }; export type UpdateThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -96,6 +96,14 @@ export type GetNativeSqlQueryVariables = Types.Exact<{ export type GetNativeSqlQuery = { __typename?: 'Query', nativeSql: string }; +export type CreateCorrectedThreadResponseMutationVariables = Types.Exact<{ + threadId: Types.Scalars['Int']; + data: Types.CreateCorrectedThreadResponseInput; +}>; + + +export type CreateCorrectedThreadResponseMutation = { __typename?: 'Mutation', createCorrectedThreadResponse: { __typename?: 'ThreadResponse', id: number, question: string, summary: string, status: Types.AskingTaskStatus, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null, detail?: { __typename?: 'ThreadResponseDetail', sql?: string | null, description?: string | null, steps: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }>, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null } | null, corrections?: Array<{ __typename?: 'CorrectionDetail', id: number, type: Types.ReferenceType, correction: string }> | null } }; + export const CommonErrorFragmentDoc = gql` fragment CommonError on Error { code @@ -125,6 +133,11 @@ export const CommonResponseFragmentDoc = gql` displayName } } + corrections { + id + type + correction + } } `; export const SuggestedQuestionsDocument = gql` @@ -435,14 +448,12 @@ export const CreateThreadResponseDocument = gql` createThreadResponse(threadId: $threadId, data: $data) { ...CommonResponse error { - code - shortMessage - message - stacktrace + ...CommonError } } } - ${CommonResponseFragmentDoc}`; + ${CommonResponseFragmentDoc} +${CommonErrorFragmentDoc}`; export type CreateThreadResponseMutationFn = Apollo.MutationFunction; /** @@ -600,4 +611,42 @@ export function useGetNativeSqlLazyQuery(baseOptions?: Apollo.LazyQueryHookOptio } export type GetNativeSqlQueryHookResult = ReturnType; export type GetNativeSqlLazyQueryHookResult = ReturnType; -export type GetNativeSqlQueryResult = Apollo.QueryResult; \ No newline at end of file +export type GetNativeSqlQueryResult = Apollo.QueryResult; +export const CreateCorrectedThreadResponseDocument = gql` + mutation CreateCorrectedThreadResponse($threadId: Int!, $data: CreateCorrectedThreadResponseInput!) { + createCorrectedThreadResponse(threadId: $threadId, data: $data) { + ...CommonResponse + error { + ...CommonError + } + } +} + ${CommonResponseFragmentDoc} +${CommonErrorFragmentDoc}`; +export type CreateCorrectedThreadResponseMutationFn = Apollo.MutationFunction; + +/** + * __useCreateCorrectedThreadResponseMutation__ + * + * To run a mutation, you first call `useCreateCorrectedThreadResponseMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCreateCorrectedThreadResponseMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [createCorrectedThreadResponseMutation, { data, loading, error }] = useCreateCorrectedThreadResponseMutation({ + * variables: { + * threadId: // value for 'threadId' + * data: // value for 'data' + * }, + * }); + */ +export function useCreateCorrectedThreadResponseMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CreateCorrectedThreadResponseDocument, options); + } +export type CreateCorrectedThreadResponseMutationHookResult = ReturnType; +export type CreateCorrectedThreadResponseMutationResult = Apollo.MutationResult; +export type CreateCorrectedThreadResponseMutationOptions = Apollo.BaseMutationOptions; \ No newline at end of file diff --git a/wren-ui/src/apollo/client/graphql/home.ts b/wren-ui/src/apollo/client/graphql/home.ts index 70758cfb4..44117bb92 100644 --- a/wren-ui/src/apollo/client/graphql/home.ts +++ b/wren-ui/src/apollo/client/graphql/home.ts @@ -30,6 +30,11 @@ const COMMON_RESPONSE = gql` displayName } } + corrections { + id + type + correction + } } `; @@ -139,10 +144,7 @@ export const CREATE_THREAD_RESPONSE = gql` createThreadResponse(threadId: $threadId, data: $data) { ...CommonResponse error { - code - shortMessage - message - stacktrace + ...CommonError } } } @@ -180,3 +182,17 @@ export const GET_NATIVE_SQL = gql` nativeSql(responseId: $responseId) } `; + +export const CREATE_CORRECTED_THREAD_RESPONSE = gql` + mutation CreateCorrectedThreadResponse( + $threadId: Int! + $data: CreateCorrectedThreadResponseInput! + ) { + createCorrectedThreadResponse(threadId: $threadId, data: $data) { + ...CommonResponse + error { + ...CommonError + } + } + } +`; diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index 1e9a3c47a..472bed38f 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -80,12 +80,14 @@ export enum AskCandidateType { LLM = 'LLM', } -export enum ExplainType { +// The enum key's name refer to schema.ts > ReferenceType +// It helps mapping AI service enum values: selectItems, relation, filter, sortings, groupByKeys +export enum ExplanationType { + FIELD = 'selectItems', + QUERY_FROM = 'relation', FILTER = 'filter', - SELECT_ITEMS = 'selectItems', - RELATION = 'relation', - GROUP_BY_KEYS = 'groupByKeys', - SORTINGS = 'sortings', + SORTING = 'sortings', + GROUP_BY = 'groupByKeys', } // UI currently only support nl_expression @@ -129,18 +131,21 @@ export interface CorrectionObject { value: string; } -export interface AskCorrection { - before: CorrectionObject; +export interface AskCorrectionInput { + before: CorrectionObject; after: CorrectionObject; } -export interface AskStepWithCorrections extends AskStep { - corrections: AskCorrection[]; +export interface AskStepWithCorrectionsInput { + summary: string; + sql: string; + cte_name: string; + corrections: AskCorrectionInput[]; } export interface RegenerateAskDetailInput { description: string; - steps: AskStepWithCorrections[]; + steps: AskStepWithCorrectionsInput[]; } const getAIServiceError = (error: any) => { diff --git a/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts index 1738f7d80..7e7cb63c2 100644 --- a/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts +++ b/wren-ui/src/apollo/server/backgroundTrackers/threadResponseBackgroundTracker.ts @@ -1,4 +1,8 @@ -import { AskResultStatus, IWrenAIAdaptor } from '../adaptors/wrenAIAdaptor'; +import { + AskDetailResult, + AskResultStatus, + IWrenAIAdaptor, +} from '../adaptors/wrenAIAdaptor'; import { IThreadResponseRepository, ThreadResponse, @@ -14,25 +18,29 @@ export class ThreadResponseBackgroundTracker extends BackgroundTracker { const jobs = Object.values(this.tasks).map( (threadResponse) => async () => { @@ -45,7 +53,7 @@ export class ThreadResponseBackgroundTracker extends BackgroundTracker + this.wrenAIAdaptor.getAskDetailResult(queryId), + sendTelemetry: ( + threadResponse: ThreadResponse, + result: AskDetailResult, + ) => { + this.telemetry.send_event('question_answered', { + question: threadResponse.question, + result, + }); + }, + }; + // For regenerated thread response + if (this.isRegenerated) { + strategy.name = 'RegeneratedThreadResponse'; + strategy.getAskDetailResult = (queryId) => + this.wrenAIAdaptor.getRegeneratedAskDetailResult(queryId); + strategy.sendTelemetry = (threadResponse, result) => { + this.telemetry.send_event('regenerated_question_answered', { + question: threadResponse.question, + corrections: threadResponse.corrections, + result, + }); + }; + } + return strategy; + } + private isFinalized = (status: AskResultStatus) => { return ( status === AskResultStatus.FAILED || diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts index 47690d172..2f1647c84 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts @@ -4,7 +4,14 @@ import { IBasicRepository, IQueryOptions, } from './baseRepository'; -import { camelCase, isPlainObject, mapKeys, mapValues } from 'lodash'; +import { + camelCase, + snakeCase, + Dictionary, + isPlainObject, + mapKeys, + mapValues, +} from 'lodash'; import { AskResultStatus, WrenAIError } from '../adaptors/wrenAIAdaptor'; export interface DetailStep { @@ -19,6 +26,12 @@ export interface ThreadResponseDetail { steps: Array; } +export interface PrevCorrection { + id: number; + type: string; + correction: string; +} + export interface ThreadResponse { id: number; // ID threadId: number; // Reference to thread.id @@ -28,6 +41,7 @@ export interface ThreadResponse { status: string; // Thread response status detail: ThreadResponseDetail; // Thread response detail error: object; // Thread response error + corrections: PrevCorrection[]; // Previous thread response corrections } export interface ThreadResponseWithThreadContext extends ThreadResponse { @@ -63,27 +77,9 @@ export class ThreadResponseRepository query.orderBy('created_at', 'desc').limit(limit); } - return (await query) - .map((res) => { - // turn object keys into camelCase - return mapKeys(res, (_, key) => camelCase(key)); - }) - .map((res) => { - // JSON.parse detail and error - const detail = - res.detail && typeof res.detail === 'string' - ? JSON.parse(res.detail) - : res.detail; - const error = - res.error && typeof res.error === 'string' - ? JSON.parse(res.error) - : res.error; - return { - ...res, - detail: detail || null, - error: error || null, - }; - }) as ThreadResponseWithThreadContext[]; + return (await query).map((res) => + this.transformFromDBData(res), + ) as ThreadResponseWithThreadContext[]; } public async updateOne( @@ -108,13 +104,31 @@ export class ThreadResponseRepository return this.transformFromDBData(result); } + protected override transformToDBData = (data: any) => { + if (!isPlainObject(data)) { + throw new Error('Unexpected dbdata'); + } + const formattedData = mapValues(data, (value, key) => { + if (['error', 'detail', 'corrections'].includes(key)) { + // The value from Sqlite will be string type, while the value from PG is JSON object + if (value) { + return typeof value === 'string' ? value : JSON.stringify(value); + } else { + return value; + } + } + return value; + }) as Dictionary; + return mapKeys(formattedData, (_value, key) => snakeCase(key)); + }; + protected override transformFromDBData = (data: any): ThreadResponse => { if (!isPlainObject(data)) { throw new Error('Unexpected dbdata'); } const camelCaseData = mapKeys(data, (_value, key) => camelCase(key)); const formattedData = mapValues(camelCaseData, (value, key) => { - if (['error', 'detail'].includes(key)) { + if (['error', 'detail', 'corrections'].includes(key)) { // The value from Sqlite will be string type, while the value from PG is JSON object if (typeof value === 'string') { return value ? JSON.parse(value) : value; diff --git a/wren-ui/src/apollo/server/resolvers.ts b/wren-ui/src/apollo/server/resolvers.ts index b5b20d278..0a2e718fb 100644 --- a/wren-ui/src/apollo/server/resolvers.ts +++ b/wren-ui/src/apollo/server/resolvers.ts @@ -78,6 +78,7 @@ const resolvers = { updateThread: askingResolver.updateThread, deleteThread: askingResolver.deleteThread, createThreadResponse: askingResolver.createThreadResponse, + createCorrectedThreadResponse: askingResolver.createCorrectedThreadResponse, previewData: askingResolver.previewData, // Views diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index d76ee928c..691cb685c 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -54,6 +54,8 @@ export class AskingResolver { this.deleteThread = this.deleteThread.bind(this); this.listThreads = this.listThreads.bind(this); this.createThreadResponse = this.createThreadResponse.bind(this); + this.createCorrectedThreadResponse = + this.createCorrectedThreadResponse.bind(this); this.getResponse = this.getResponse.bind(this); this.getSuggestedQuestions = this.getSuggestedQuestions.bind(this); } @@ -188,6 +190,7 @@ export class AskingResolver { status: response.status, detail: response.detail, error: response.error, + corrections: response.corrections, }); return acc; @@ -257,6 +260,34 @@ export class AskingResolver { return response; } + public async createCorrectedThreadResponse( + _root: any, + args: { + threadId: number; + data: { + responseId: number; + corrections: { + id: number; + type: string; + stepIndex: number; + reference: string; + correction: string; + }[]; + }; + }, + ctx: IContext, + ): Promise { + const { threadId, data } = args; + + const askingService = ctx.askingService; + const response = await askingService.createCorrectedThreadResponse( + threadId, + data, + ); + ctx.telemetry.send_event('regenerate_asked_question', {}); + return response; + } + public async getResponse( _root: any, args: { responseId: number }, diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 193f2c777..7c0904506 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -491,6 +491,14 @@ export const typeDefs = gql` LLM # LLM type candidate is created by LLM } + enum ReferenceType { + FIELD + QUERY_FROM + FILTER + SORTING + GROUP_BY + } + type ResultCandidate { type: ResultCandidateType! sql: String! @@ -519,6 +527,19 @@ export const typeDefs = gql` viewId: Int } + input CreateThreadResponseCorrectionInput { + id: Int! + stepIndex: Int! + type: ReferenceType! + reference: String! + correction: String! + } + + input CreateCorrectedThreadResponseInput { + responseId: Int! + corrections: [CreateThreadResponseCorrectionInput!]! + } + input ThreadUniqueWhereInput { id: Int! } @@ -549,6 +570,12 @@ export const typeDefs = gql` steps: [DetailStep!]! } + type CorrectionDetail { + id: Int! + type: ReferenceType! + correction: String! + } + type ThreadResponse { id: Int! question: String! @@ -556,6 +583,7 @@ export const typeDefs = gql` status: AskingTaskStatus! detail: ThreadResponseDetail error: Error + corrections: [CorrectionDetail!] } # Thread only consists of basic information of a thread @@ -757,6 +785,10 @@ export const typeDefs = gql` threadId: Int! data: CreateThreadResponseInput! ): ThreadResponse! + createCorrectedThreadResponse( + threadId: Int! + data: CreateCorrectedThreadResponseInput! + ): ThreadResponse! previewData(where: PreviewDataInput!): JSON! # Settings diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 4ec533ea5..27a853391 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -3,6 +3,8 @@ import { IWrenAIAdaptor, AskResultStatus, AskHistory, + ExpressionType, + ExplanationType, } from '@server/adaptors/wrenAIAdaptor'; import { IDeployService } from './deployService'; import { IProjectService } from './projectService'; @@ -13,7 +15,7 @@ import { ThreadResponseWithThreadContext, } from '../repositories/threadResponseRepository'; import { getLogger } from '@server/utils'; -import { isEmpty, isNil } from 'lodash'; +import { groupBy, isEmpty, isNil } from 'lodash'; import { format } from 'sql-formatter'; import { Telemetry } from '../telemetry/telemetry'; import { IViewRepository, View } from '../repositories'; @@ -41,6 +43,19 @@ export interface AskingDetailTaskInput { viewId?: number; } +export interface CorrectionInput { + id: number; + type: string; + stepIndex: number; + reference: string; + correction: string; +} + +export interface CorrectedDetailTaskInput { + responseId: number; + corrections: CorrectionInput[]; +} + export interface IAskingService { /** * Asking task. @@ -63,6 +78,10 @@ export interface IAskingService { threadId: number, input: AskingDetailTaskInput, ): Promise; + createCorrectedThreadResponse( + threadId: number, + input: CorrectedDetailTaskInput, + ): Promise; getResponsesWithThread( threadId: number, ): Promise; @@ -141,6 +160,7 @@ export class AskingService implements IAskingService { private threadRepository: IThreadRepository; private threadResponseRepository: IThreadResponseRepository; private backgroundTracker: ThreadResponseBackgroundTracker; + private regeneratedBackgroundTracker: ThreadResponseBackgroundTracker; private queryService: IQueryService; private telemetry: Telemetry; @@ -176,6 +196,12 @@ export class AskingService implements IAskingService { wrenAIAdaptor, threadResponseRepository, }); + this.regeneratedBackgroundTracker = new ThreadResponseBackgroundTracker({ + telemetry, + wrenAIAdaptor, + threadResponseRepository, + isRegenerated: true, + }); } public async initialize() { @@ -190,6 +216,10 @@ export class AskingService implements IAskingService { `Initialization: adding unfinished thread responses (total: ${unfinishedThreadResponses.length}) to background tracker`, ); for (const threadResponse of unfinishedThreadResponses) { + if (threadResponse.corrections !== null) { + this.regeneratedBackgroundTracker.addTask(threadResponse); + continue; + } this.backgroundTracker.addTask(threadResponse); } } @@ -342,6 +372,63 @@ export class AskingService implements IAskingService { return threadResponse; } + public async createCorrectedThreadResponse( + threadId: number, + input: CorrectedDetailTaskInput, + ): Promise { + const thread = await this.threadRepository.findOneBy({ + id: threadId, + }); + + const baseThreadResponse = await this.threadResponseRepository.findOneBy({ + id: input.responseId, + }); + + if (!baseThreadResponse) { + throw new Error(`Thread response ${input.responseId} not found`); + } + + const correctionsMap = groupBy(input.corrections, 'stepIndex'); + const response = await this.wrenAIAdaptor.regenerateAskDetail({ + description: baseThreadResponse.detail.description, + steps: baseThreadResponse.detail.steps.map((step, index) => ({ + sql: step.sql, + summary: step.summary, + cte_name: step.cteName, + corrections: (correctionsMap[index] || []).map((item) => ({ + before: { + type: ExplanationType[item.type], + value: item.reference, + }, + after: { + // Only NL_EXPRESSION is supported for now + type: ExpressionType.NL_EXPRESSION, + value: item.correction, + }, + })), + })), + }); + + const threadResponse = await this.threadResponseRepository.createOne({ + threadId: thread.id, + queryId: response.queryId, + question: baseThreadResponse.question, + summary: baseThreadResponse.summary, + status: AskResultStatus.UNDERSTANDING, + corrections: input.corrections.map((item) => ({ + id: item.id, + type: item.type, + correction: item.correction, + })), + }); + + // 3. put the task into background tracker + this.regeneratedBackgroundTracker.addTask(threadResponse); + + // return the task id + return threadResponse; + } + public async getResponsesWithThread(threadId: number) { return this.threadResponseRepository.getResponsesWithThread(threadId); } diff --git a/wren-ui/src/components/pages/home/thread/AnswerResult.tsx b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx index 0a8aa90aa..8af5961f4 100644 --- a/wren-ui/src/components/pages/home/thread/AnswerResult.tsx +++ b/wren-ui/src/components/pages/home/thread/AnswerResult.tsx @@ -1,19 +1,27 @@ import { useState } from 'react'; import Link from 'next/link'; -import { Col, Button, Row, Skeleton, Typography } from 'antd'; +import { Col, Button, Row, Skeleton, Typography, Divider, Tag } from 'antd'; import styled from 'styled-components'; import { Path } from '@/utils/enum'; import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; +import ShareAltOutlined from '@ant-design/icons/ShareAltOutlined'; import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; import SaveOutlined from '@ant-design/icons/SaveOutlined'; import FileDoneOutlined from '@ant-design/icons/FileDoneOutlined'; import StepContent from '@/components/pages/home/thread/StepContent'; import FeedbackLayout from '@/components/pages/home/thread/feedback'; +import { + AskingTaskStatus, + ThreadResponse, +} from '@/apollo/client/graphql/__types__'; +import { makeIterable } from '@/utils/iteration'; +import { getReferenceIcon } from '@/components/pages/home/thread/feedback/utils'; const { Title, Text } = Typography; const Wrapper = styled.div` width: 680px; + flex-shrink: 0; `; const StyledAnswer = styled(Typography)` @@ -45,60 +53,107 @@ const StyledQuestion = styled(Row)` `; interface Props { - loading: boolean; - question: string; - description: string; - answerResultSteps: Array<{ - summary: string; - sql: string; - }>; - fullSql: string; - threadResponseId: number; - onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; + threadResponse: ThreadResponse; isLastThreadResponse: boolean; + onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; onTriggerScrollToBottom: () => void; - summary: string; - view?: { - id: number; - displayName: string; - }; + onSubmitReviewDrawer: (variables: any) => Promise; } +const CorrectionTemplate = ({ id, type, correction }) => { + return ( +
+ + {getReferenceIcon(type)} + {id} + + {correction} +
+ ); +}; +const CorrectionIterator = makeIterable(CorrectionTemplate); +const RegenerateInformation = (props: ThreadResponse) => { + const { question, corrections } = props; + const [collapse, setCollapse] = useState(false); + const collapseText = collapse ? 'Hide' : 'Show'; + return ( +
+
+ +
Regenerated answer from
+ setCollapse(!collapse)} + > + {collapseText} feedbacks + +
+ + {question} + + {collapse && ( +
+ +
+ +
Feedbacks
+
+ +
+ )} +
+ ); +}; + +const QuestionInformation = (props) => { + const { question } = props; + const [ellipsis, setEllipsis] = useState(true); + return ( + setEllipsis(!ellipsis)}> + + + Question: + + + + {question} + + + + ); +}; + export default function AnswerResult(props: Props) { const { - loading, - question, - description, - answerResultSteps, - fullSql, - threadResponseId, + threadResponse, isLastThreadResponse, onOpenSaveAsViewModal, onTriggerScrollToBottom, - summary, - view, + onSubmitReviewDrawer, } = props; + const { id: responseId, summary, status, corrections } = threadResponse; + const { + view, + steps, + description, + sql: fullSql, + } = threadResponse?.detail || {}; + const isViewSaved = !!view; + const isRegenerated = !!corrections; + const loading = status !== AskingTaskStatus.FINISHED; - const [ellipsis, setEllipsis] = useState(true); + const Information = isRegenerated + ? RegenerateInformation + : QuestionInformation; return ( - setEllipsis(!ellipsis)}> - - - Question: - - - - {question} - - - + {summary} @@ -112,15 +167,15 @@ export default function AnswerResult(props: Props) { Summary
{description}
- {(answerResultSteps || []).map((step, index) => ( + {(steps || []).map((step, index) => ( @@ -146,10 +201,7 @@ export default function AnswerResult(props: Props) { size="small" icon={} onClick={() => - onOpenSaveAsViewModal({ - sql: fullSql, - responseId: threadResponseId, - }) + onOpenSaveAsViewModal({ sql: fullSql, responseId }) } > Save as View @@ -157,6 +209,8 @@ export default function AnswerResult(props: Props) { )} } + threadResponse={threadResponse} + onSubmitReviewDrawer={onSubmitReviewDrawer} />
); diff --git a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx index ba52bac84..db419ccaa 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx @@ -3,6 +3,7 @@ import { useMemo } from 'react'; import styled from 'styled-components'; import { Button, Popconfirm } from 'antd'; import { FileTextOutlined } from '@ant-design/icons'; +import { Reference } from './utils'; const StyledFeedbackSideFloat = styled.div` position: relative; @@ -17,14 +18,18 @@ const StyledFeedbackSideFloat = styled.div` interface Props { className?: string; - references: any[]; + references: Reference[]; onOpenReviewDrawer: () => void; - onResetAllChanges: () => void; + onResetAllCorrectionPrompts: () => void; } export default function FeedbackSideFloat(props: Props) { - const { className, references, onOpenReviewDrawer, onResetAllChanges } = - props; + const { + className, + references, + onOpenReviewDrawer, + onResetAllCorrectionPrompts, + } = props; const changedReferences = useMemo(() => { return (references || []).filter((item) => !!item.correctionPrompt); @@ -51,7 +56,7 @@ export default function FeedbackSideFloat(props: Props) { title="Are you sure?" okText="Confirm" okButtonProps={{ danger: true }} - onConfirm={onResetAllChanges} + onConfirm={onResetAllCorrectionPrompts} > - )} - - } - threadResponse={threadResponse} - onSubmitReviewDrawer={onSubmitReviewDrawer} - /> - +
+ {error ? ( + + ) : ( + + + + + {summary} + + + } + bodySlot={ + + + + + Summary + +
{description}
+ {(steps || []).map((step, index) => ( + + ))} +
+ {isViewSaved ? ( +
+ + Generated from saved view{' '} + + {view.displayName} + +
+ ) : ( + + )} +
+ } + threadResponse={threadResponse} + onSubmitReviewDrawer={onSubmitReviewDrawer} + onTriggerThreadResponseExplain={onTriggerThreadResponseExplain} + /> +
+ )} +
); } diff --git a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx index db419ccaa..7af3cb242 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/FeedbackSideFloat.tsx @@ -7,7 +7,6 @@ import { Reference } from './utils'; const StyledFeedbackSideFloat = styled.div` position: relative; - width: 325px; .feedbackSideFloat-title { position: absolute; diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx index f0de93e9e..6a3cc3a57 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx @@ -1,16 +1,26 @@ import clsx from 'clsx'; +import { groupBy } from 'lodash'; import { useMemo, useState } from 'react'; import styled from 'styled-components'; -import { Tag, Typography, Button, Input } from 'antd'; -import { EditOutlined } from '@ant-design/icons'; +import { Tag, Typography, Button, Input, Alert } from 'antd'; +import { + EditOutlined, + CloseCircleFilled, + ReloadOutlined, + InfoCircleFilled, +} from '@ant-design/icons'; import { QuoteIcon } from '@/utils/icons'; import { makeIterable } from '@/utils/iteration'; -import { Reference, getReferenceIcon } from './utils'; -import { ReferenceType } from '@/apollo/client/graphql/__types__'; +import { + REFERENCE_ORDERS, + Reference, + getReferenceIcon, + getReferenceName, +} from './utils'; +import { ERROR_CODES } from '@/utils/errorHandler'; const StyledReferenceSideFloat = styled.div` position: relative; - width: 330px; .referenceSideFloat-title { position: absolute; @@ -19,23 +29,48 @@ const StyledReferenceSideFloat = styled.div` } `; +const StyledAlert = styled(Alert)` + padding: 8px 12px 12px; + .ant-alert-icon { + font-size: 14px; + margin-right: 8px; + margin-top: 4px; + } + .ant-alert-message { + font-size: 14px; + line-height: 14px; + margin-top: 4px; + margin-bottom: 8px; + } + .ant-alert-description { + font-size: 12px; + line-height: 16px; + color: var(--gray-8); + } +`; + interface Props { references: Reference[]; - onSaveCorrectionPrompt?: (id: string, value: string) => void; + error?: Record; + onSaveCorrectionPrompt: (id: string, value: string) => void; + onTriggerExplanation: () => void; } -const COLLAPSE_LIMIT = 3; - -const ReferenceSummaryTemplate = ({ id, title, type, correctionPrompt }) => { +const ReferenceSummaryTemplate = ({ + summary, + type, + referenceNum, + correctionPrompt, +}) => { const isRevise = !!correctionPrompt; return (
{getReferenceIcon(type)} - {id} + {referenceNum} - {title} + {summary}
); @@ -64,9 +99,10 @@ const GroupReferenceTemplate = ({ }; const ReferenceTemplate = ({ - id, - title, type, + summary, + referenceId, + referenceNum, correctionPrompt, saveCorrectionPrompt, }) => { @@ -79,7 +115,7 @@ const ReferenceTemplate = ({ }; const handleEdit = () => { - saveCorrectionPrompt(id, value); + saveCorrectionPrompt(referenceId, value); setIsEdit(false); setValue(''); }; @@ -89,12 +125,12 @@ const ReferenceTemplate = ({
{getReferenceIcon(type)} - {id} + {referenceNum}
- {title} + {summary} {isRevise ? ( '(feedback suggested)' @@ -136,38 +172,12 @@ const ReferenceIterator = makeIterable(ReferenceTemplate); const References = (props: Props) => { const { references, onSaveCorrectionPrompt } = props; - - const fieldReferences = references.filter( - (ref) => ref.type === ReferenceType.FIELD, - ); - const queryFromReferences = references.filter( - (ref) => ref.type === ReferenceType.QUERY_FROM, - ); - const filterReferences = references.filter( - (ref) => ref.type === ReferenceType.FILTER, - ); - const sortingReferences = references.filter( - (ref) => ref.type === ReferenceType.SORTING, - ); - const groupByReferences = references.filter( - (ref) => ref.type === ReferenceType.GROUP_BY, - ); - - const resources = [ - { name: 'Fields', type: ReferenceType.FIELD, data: fieldReferences }, - { - name: 'Query from', - type: ReferenceType.QUERY_FROM, - data: queryFromReferences, - }, - { name: 'Filter', type: ReferenceType.FILTER, data: filterReferences }, - { name: 'Sorting', type: ReferenceType.SORTING, data: sortingReferences }, - { - name: 'Group by', - type: ReferenceType.GROUP_BY, - data: groupByReferences, - }, - ]; + const referencesByGroup = groupBy(references, 'type'); + const resources = REFERENCE_ORDERS.map((type) => ({ + type, + name: getReferenceName(type), + data: referencesByGroup[type] || [], + })); return (
{ }; export default function ReferenceSideFloat(props: Props) { - const { references } = props; + const { references, error, onTriggerExplanation } = props; const [collapse, setCollapse] = useState(false); const referencesSummary = useMemo( - () => references.slice(0, COLLAPSE_LIMIT), + () => + references.reduce((result, reference) => { + if (!result[reference.stepIndex]) { + result[reference.stepIndex] = reference; + } + return result; + }, []), [collapse, references], ); @@ -195,7 +211,34 @@ export default function ReferenceSideFloat(props: Props) { setCollapse(!collapse); }; - if (references.length === 0) return null; + if (error) { + // If the thread response was created before the release of the Feedback Loop Feature, + // the explanation will be migrated with an error code OLD_VERSION. + // In this case, users will need to manually trigger the explanation. + const isOldVersion = error.code === ERROR_CODES.OLD_VERSION; + const icon = isOldVersion ? : ; + const type = isOldVersion ? 'info' : 'error'; + return ( + } + onClick={onTriggerExplanation} + > + Retry + + } + /> + ); + } else if (references.length === 0) return null; return (
diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx index 83e31894a..ef8e5084e 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/ReviewDrawer.tsx @@ -37,9 +37,10 @@ const StyledOriginal = styled.div` `; const ReviewTemplate = ({ - id, - title, type, + summary, + referenceId, + referenceNum, correctionPrompt, saveCorrectionPrompt, removeCorrectionPrompt, @@ -53,11 +54,11 @@ const ReviewTemplate = ({ }; const openDelete = () => { - removeCorrectionPrompt(id); + removeCorrectionPrompt(referenceId); }; const handleEdit = (event) => { - saveCorrectionPrompt(id, event.target.value); + saveCorrectionPrompt(referenceId, event.target.value); setIsEdit(false); }; @@ -72,14 +73,14 @@ const ReviewTemplate = ({ Reference: - {title} + {summary}
{getReferenceIcon(type)} - {id} + {referenceNum}
@@ -146,9 +147,10 @@ export default function ReviewDrawer(props: Props) { const data: CreateCorrectedThreadResponseInput = { responseId: threadResponseId, corrections: changedReferences.map((reference) => ({ - id: reference.id, + id: reference.referenceId, type: reference.type, - reference: reference.title, + reference: reference.summary, + referenceNum: reference.referenceNum, stepIndex: reference.stepIndex, correction: reference.correctionPrompt, })), diff --git a/wren-ui/src/components/pages/home/thread/feedback/index.tsx b/wren-ui/src/components/pages/home/thread/feedback/index.tsx index 7757e0661..04f1bbd85 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/index.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/index.tsx @@ -1,15 +1,16 @@ import { createContext, useContext, useMemo, useState } from 'react'; +import { sortBy } from 'lodash'; +import { Skeleton } from 'antd'; import ReferenceSideFloat from '@/components/pages/home/thread/feedback/ReferenceSideFloat'; import FeedbackSideFloat from '@/components/pages/home/thread/feedback/FeedbackSideFloat'; import ReviewDrawer from '@/components/pages/home/thread/feedback/ReviewDrawer'; import useDrawerAction from '@/hooks/useDrawerAction'; -import { - ReferenceType, - ThreadResponse, -} from '@/apollo/client/graphql/__types__'; +import { ThreadResponse } from '@/apollo/client/graphql/__types__'; +import { Reference, REFERENCE_ORDERS } from './utils'; +import { getIsExplainFinished } from '@/hooks/useAskPrompt'; type ContextProps = { - references: any[]; + references: Reference[]; } | null; export const FeedbackContext = createContext({ @@ -25,63 +26,17 @@ interface Props { bodySlot: React.ReactNode; threadResponse: ThreadResponse; onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; } -const data = [ - { - id: 1, - type: ReferenceType.FIELD, - stepIndex: 0, - title: - "Selects the 'City' column from the 'customer_data' dataset to display the city name.", - referenceNum: 1, - }, - { - id: 2, - type: ReferenceType.FIELD, - stepIndex: 0, - title: 'Reference 2', - referenceNum: 2, - }, - { - id: 3, - type: ReferenceType.QUERY_FROM, - stepIndex: 1, - title: 'Reference 3', - referenceNum: 3, - }, - { - id: 4, - type: ReferenceType.QUERY_FROM, - stepIndex: 1, - title: 'Reference 4', - referenceNum: 4, - }, - { - id: 5, - type: ReferenceType.FILTER, - stepIndex: 2, - title: 'Reference 4', - referenceNum: 4, - }, - { - id: 6, - type: ReferenceType.SORTING, - stepIndex: 2, - title: 'Reference 4', - referenceNum: 4, - }, - { - id: 7, - type: ReferenceType.GROUP_BY, - stepIndex: 2, - title: 'Reference 4', - referenceNum: 4, - }, -]; - export default function Feedback(props: Props) { - const { headerSlot, bodySlot, threadResponse, onSubmitReviewDrawer } = props; + const { + headerSlot, + bodySlot, + threadResponse, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + } = props; const [correctionPrompts, setCorrectionPrompts] = useState({}); const reviewDrawer = useDrawerAction(); @@ -98,31 +53,45 @@ export default function Feedback(props: Props) { setCorrectionPrompts({}); }; + const triggerExplanation = () => { + onTriggerThreadResponseExplain({ responseId: threadResponse.id }); + }; + + const loading = useMemo( + () => !getIsExplainFinished(threadResponse?.explain?.status), + [threadResponse?.explain?.status], + ); + const error = useMemo(() => { + return threadResponse?.explain?.error || null; + }, [threadResponse?.explain?.error]); const references = useMemo(() => { if (!threadResponse?.detail) return []; - return threadResponse.detail.steps.flatMap((_, index) => { - // TODO: change to real step reference's data - const references = data.filter((item) => item.stepIndex === index); - return references.map((reference) => ({ - id: reference.id, - title: reference.title, - type: reference.type, - stepIndex: reference.stepIndex, - correctionPrompt: correctionPrompts[reference.id], + const result = threadResponse.detail.steps.flatMap((step, index) => { + if (step.references === null) return []; + return step.references.map((reference) => ({ + ...reference, + stepIndex: index, + correctionPrompt: correctionPrompts[reference.referenceId], })); }); + // Generate reference number for each reference + return sortBy(result, (reference) => + REFERENCE_ORDERS.indexOf(reference.type), + ).map((reference, index) => ({ + referenceNum: index + 1, + ...reference, + })); }, [threadResponse?.detail, correctionPrompts]); const contextValue = { references, - saveCorrectionPrompt, }; return (
{headerSlot} -
+
{bodySlot} -
- +
+ + +
{ + return ( + { + [ReferenceType.FIELD]: 'Fields', + [ReferenceType.QUERY_FROM]: 'Query from', + [ReferenceType.FILTER]: 'Filter', + [ReferenceType.SORTING]: 'Sorting', + [ReferenceType.GROUP_BY]: 'Group by', + }[type] || null + ); +}; + export const getReferenceIcon = (type) => { return ( { diff --git a/wren-ui/src/components/pages/home/thread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx index b4c1b7653..791a18beb 100644 --- a/wren-ui/src/components/pages/home/thread/index.tsx +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -1,5 +1,5 @@ import React, { useEffect, useRef } from 'react'; -import { Alert, Divider } from 'antd'; +import { Divider } from 'antd'; import styled from 'styled-components'; import AnswerResult from './AnswerResult'; import { IterableComponent, makeIterable } from '@/utils/iteration'; @@ -12,6 +12,7 @@ interface Props { data: DetailedThread; onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; } const StyledThread = styled.div` @@ -38,38 +39,32 @@ const AnswerResultTemplate: React.FC< onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; onTriggerScrollToBottom: () => void; onSubmitReviewDrawer: (variables: any) => Promise; + onTriggerThreadResponseExplain: (variables: any) => Promise; } > = ({ onOpenSaveAsViewModal, onTriggerScrollToBottom, onSubmitReviewDrawer, + onTriggerThreadResponseExplain, data, index, ...threadResponse }) => { const lastResponseId = data[data.length - 1].id; const isLastThreadResponse = threadResponse.id === lastResponseId; - const { id, error } = threadResponse; + const { id } = threadResponse; return ( {index > 0 && } - {error ? ( - - ) : ( - - )} + ); }; @@ -77,7 +72,12 @@ const AnswerResultTemplate: React.FC< const AnswerResultIterator = makeIterable(AnswerResultTemplate); export default function Thread(props: Props) { - const { data, onOpenSaveAsViewModal, onSubmitReviewDrawer } = props; + const { + data, + onOpenSaveAsViewModal, + onSubmitReviewDrawer, + onTriggerThreadResponseExplain, + } = props; const divRef = useRef(null); const triggerScrollToBottom = () => { @@ -109,6 +109,7 @@ export default function Thread(props: Props) { onOpenSaveAsViewModal={onOpenSaveAsViewModal} onTriggerScrollToBottom={triggerScrollToBottom} onSubmitReviewDrawer={onSubmitReviewDrawer} + onTriggerThreadResponseExplain={onTriggerThreadResponseExplain} /> ); diff --git a/wren-ui/src/hooks/useAskPrompt.tsx b/wren-ui/src/hooks/useAskPrompt.tsx index f6c2e6249..e0a3654be 100644 --- a/wren-ui/src/hooks/useAskPrompt.tsx +++ b/wren-ui/src/hooks/useAskPrompt.tsx @@ -1,18 +1,44 @@ import { useEffect, useMemo } from 'react'; -import { AskingTaskStatus } from '@/apollo/client/graphql/__types__'; +import { + AskingTaskStatus, + ExplainTaskStatus, +} from '@/apollo/client/graphql/__types__'; import { useAskingTaskLazyQuery, useCancelAskingTaskMutation, useCreateAskingTaskMutation, } from '@/apollo/client/graphql/home.generated'; -export const getIsFinished = (status: AskingTaskStatus) => +export const getIsExplainFinished = (status: ExplainTaskStatus) => + [ExplainTaskStatus.FINISHED, ExplainTaskStatus.FAILED].includes(status); + +export const getIsAskingFinished = (status: AskingTaskStatus) => [ AskingTaskStatus.FINISHED, AskingTaskStatus.FAILED, AskingTaskStatus.STOPPED, ].includes(status); +export const checkExplainExisted = (explain?: { + queryId?: string; + status?: ExplainTaskStatus; +}) => { + // if the queryId is not empty, it means the question is explainable + return !!explain?.queryId ? explain.status : undefined; +}; + +export const getIsFinished = ( + askingStatus: AskingTaskStatus, + explainStatus?: ExplainTaskStatus, +) => { + const isAskingFinished = getIsAskingFinished(askingStatus); + if (explainStatus) { + const isExplainFinished = getIsExplainFinished(explainStatus); + return isAskingFinished && isExplainFinished; + } + return isAskingFinished; +}; + export default function useAskPrompt(threadId?: number) { const [createAskingTask, createAskingTaskResult] = useCreateAskingTaskMutation(); diff --git a/wren-ui/src/pages/api/graphql.ts b/wren-ui/src/pages/api/graphql.ts index 830f5a8d6..b0ed3246a 100644 --- a/wren-ui/src/pages/api/graphql.ts +++ b/wren-ui/src/pages/api/graphql.ts @@ -33,6 +33,7 @@ import { DataSourceMetadataService, QueryService, } from '@/apollo/server/services'; +import { ThreadResponseExplainRepository } from '@/apollo/server/repositories/threadResponseExplainRepository'; const serverConfig = getConfig(); const logger = getLogger('APOLLO'); @@ -63,6 +64,9 @@ const bootstrapServer = async () => { const deployLogRepository = new DeployLogRepository(knex); const threadRepository = new ThreadRepository(knex); const threadResponseRepository = new ThreadResponseRepository(knex); + const threadResponseExplainRepository = new ThreadResponseExplainRepository( + knex, + ); const viewRepository = new ViewRepository(knex); const schemaChangeRepository = new SchemaChangeRepository(knex); @@ -115,11 +119,13 @@ const bootstrapServer = async () => { const askingService = new AskingService({ telemetry, wrenAIAdaptor, + ibisAdaptor, deployService, projectService, viewRepository, threadRepository, threadResponseRepository, + threadResponseExplainRepository, queryService, }); @@ -178,6 +184,7 @@ const bootstrapServer = async () => { viewRepository, deployRepository: deployLogRepository, schemaChangeRepository, + threadResponseExplainRepository, }), }); await apolloServer.start(); diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index ab7e8b8b2..bf256ef47 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -8,17 +8,22 @@ import SiderLayout from '@/components/layouts/SiderLayout'; import Prompt from '@/components/pages/home/prompt'; import { useCreateCorrectedThreadResponseMutation, + useCreateThreadResponseExplainMutation, useCreateThreadResponseMutation, useThreadQuery, useThreadResponseLazyQuery, } from '@/apollo/client/graphql/home.generated'; -import useAskPrompt, { getIsFinished } from '@/hooks/useAskPrompt'; +import useAskPrompt, { + getIsFinished, + checkExplainExisted, +} from '@/hooks/useAskPrompt'; import useModalAction from '@/hooks/useModalAction'; import Thread from '@/components/pages/home/thread'; import SaveAsViewModal from '@/components/modals/SaveAsViewModal'; import { useCreateViewMutation } from '@/apollo/client/graphql/view.generated'; import { CreateCorrectedThreadResponseInput, + CreateThreadResponseExplainWhereInput, CreateThreadResponseInput, } from '@/apollo/client/graphql/__types__'; @@ -55,6 +60,9 @@ export default function HomeThread() { }; }); }; + const [createThreadResponseExplain] = useCreateThreadResponseExplainMutation({ + onError: (error) => console.error(error), + }); const [createThreadResponse] = useCreateThreadResponseMutation({ onCompleted(next) { const nextResponse = next.createThreadResponse; @@ -91,13 +99,18 @@ export default function HomeThread() { [threadResponseResult.data], ); const isFinished = useMemo( - () => getIsFinished(threadResponse?.status), + () => + getIsFinished( + threadResponse?.status, + checkExplainExisted(threadResponse?.explain), + ), [threadResponse], ); useEffect(() => { const unfinishedRespose = (thread?.responses || []).find( - (response) => !getIsFinished(response.status), + (response) => + !getIsFinished(response.status, checkExplainExisted(response?.explain)), ); if (unfinishedRespose) { @@ -106,8 +119,21 @@ export default function HomeThread() { }, [thread]); useEffect(() => { - if (isFinished) threadResponseResult.stopPolling(); - }, [isFinished]); + if (isFinished) { + threadResponseResult.stopPolling(); + + const isSuccessBreakdown = threadResponse?.error === null; + const isExplainable = + threadResponse?.explain && + threadResponse?.explain?.error === null && + threadResponse?.explain.queryId === null; + if (isSuccessBreakdown && isExplainable) { + createThreadResponseExplain({ + variables: { where: { responseId: threadResponse.id } }, + }).then(() => threadResponseResult.startPolling(1000)); + } + } + }, [isFinished, threadResponse]); const onSelect = async (payload: CreateThreadResponseInput) => { try { @@ -139,12 +165,26 @@ export default function HomeThread() { } }; + const onTriggerThreadResponseExplain = async ( + payload: CreateThreadResponseExplainWhereInput, + ) => { + try { + await createThreadResponseExplain({ + variables: { where: payload }, + }); + fetchThreadResponse({ variables: payload }); + } catch (error) { + console.error(error); + } + }; + return (
Date: Fri, 19 Jul 2024 16:03:38 +0800 Subject: [PATCH 07/13] fix(wren-ui): check explain finish condition, not polling thread response when explain status is undefined --- wren-ui/src/hooks/useAskPrompt.tsx | 2 +- wren-ui/src/pages/home/[id].tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wren-ui/src/hooks/useAskPrompt.tsx b/wren-ui/src/hooks/useAskPrompt.tsx index e0a3654be..698a0c378 100644 --- a/wren-ui/src/hooks/useAskPrompt.tsx +++ b/wren-ui/src/hooks/useAskPrompt.tsx @@ -32,7 +32,7 @@ export const getIsFinished = ( explainStatus?: ExplainTaskStatus, ) => { const isAskingFinished = getIsAskingFinished(askingStatus); - if (explainStatus) { + if (explainStatus !== undefined) { const isExplainFinished = getIsExplainFinished(explainStatus); return isAskingFinished && isExplainFinished; } diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index bf256ef47..f6e4cf47b 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -172,7 +172,7 @@ export default function HomeThread() { await createThreadResponseExplain({ variables: { where: payload }, }); - fetchThreadResponse({ variables: payload }); + await fetchThreadResponse({ variables: payload }); } catch (error) { console.error(error); } From d481ac14a716d646d856b10dff0f0374fa8cf464 Mon Sep 17 00:00:00 2001 From: onlyjackfrost Date: Fri, 19 Jul 2024 17:02:12 +0800 Subject: [PATCH 08/13] fix distinctOn not supported by sqlite --- .../threadResponseExplainRepository.ts | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts index 2939fad7e..784239f32 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseExplainRepository.ts @@ -6,7 +6,8 @@ import { } from './baseRepository'; import { camelCase, isPlainObject, mapKeys, mapValues } from 'lodash'; import { ExplainPipelineStatus } from '../adaptors/wrenAIAdaptor'; - +import { getConfig } from '../config'; +const config = getConfig(); export interface DetailStep { summary: string; sql: string; @@ -96,20 +97,40 @@ export class ThreadResponseExplainRepository public async findAllByThread( threadId: number, ): Promise { + if (config.dbType === 'pg') { + return this.knex('thread_response as tr') + .join( + this.knex(this.tableName) + .distinctOn('thread_response_id') + .select('id', 'thread_response_id', 'detail', 'error', 'created_at') + .orderBy([ + 'thread_response_id', + { column: 'created_at', order: 'desc' }, + ]) + .as('tre'), + 'tre.thread_response_id', + 'tr.id', + ) + .select('*') + .where('tr.thread_id', threadId) + .then((results) => results.map(this.transformFromDBData)); + } return this.knex('thread_response as tr') .join( this.knex(this.tableName) - .distinctOn('thread_response_id') - .select('id', 'thread_response_id', 'detail', 'error', 'created_at') - .orderBy([ - 'thread_response_id', - { column: 'created_at', order: 'desc' }, - ]) + .select() + .whereIn( + ['thread_response_id', 'id'], + this.knex(this.tableName) + .select('thread_response_id') + .max('id') + .groupBy('thread_response_id'), + ) .as('tre'), - 'tre.thread_response_id', 'tr.id', + 'tre.thread_response_id', ) - .select('*') + .select('tre.*') .where('tr.thread_id', threadId) .then((results) => results.map(this.transformFromDBData)); } From 91015ac43e7b222124ca7f12ede341d97a996074 Mon Sep 17 00:00:00 2001 From: andreashimin Date: Fri, 19 Jul 2024 17:27:16 +0800 Subject: [PATCH 09/13] fix(wren-ui): fix vulnerability issue when switching page fast --- wren-ui/src/components/pages/home/thread/index.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wren-ui/src/components/pages/home/thread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx index 791a18beb..2d5a2a179 100644 --- a/wren-ui/src/components/pages/home/thread/index.tsx +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -81,7 +81,8 @@ export default function Thread(props: Props) { const divRef = useRef(null); const triggerScrollToBottom = () => { - const contentLayout = divRef.current.parentElement; + const contentLayout = divRef.current?.parentElement; + if (!contentLayout) return; const lastChild = divRef.current.lastElementChild as HTMLElement; const lastChildElement = lastChild.lastElementChild as HTMLElement; From b366186094fd89471ec16280e63893d5f7bd6fe0 Mon Sep 17 00:00:00 2001 From: andreashimin Date: Mon, 22 Jul 2024 10:04:23 +0800 Subject: [PATCH 10/13] fix(wren-ui): create explanation if found any unstarted explanations --- wren-ui/src/pages/home/[id].tsx | 49 +++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index f6e4cf47b..2d7f20fb2 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -1,6 +1,6 @@ import { useRouter } from 'next/router'; import { useParams } from 'next/navigation'; -import { useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo } from 'react'; import { message } from 'antd'; import { Path } from '@/utils/enum'; import useHomeSidebar from '@/hooks/useHomeSidebar'; @@ -25,6 +25,7 @@ import { CreateCorrectedThreadResponseInput, CreateThreadResponseExplainWhereInput, CreateThreadResponseInput, + ThreadResponse, } from '@/apollo/client/graphql/__types__'; export default function HomeThread() { @@ -107,14 +108,43 @@ export default function HomeThread() { [threadResponse], ); + const startThreadResponseExplanation = useCallback( + (threadResponse: ThreadResponse) => { + const isSuccessBreakdown = threadResponse?.error === null; + const isExplainable = + threadResponse?.explain && + threadResponse?.explain?.error === null && + threadResponse?.explain.queryId === null; + if (isSuccessBreakdown && isExplainable) { + createThreadResponseExplain({ + variables: { where: { responseId: threadResponse.id } }, + }).then(() => + fetchThreadResponse({ variables: { responseId: threadResponse.id } }), + ); + } + }, + [], + ); + useEffect(() => { - const unfinishedRespose = (thread?.responses || []).find( + const unfinishedResposes = (thread?.responses || []).filter( (response) => !getIsFinished(response.status, checkExplainExisted(response?.explain)), ); + if (!!unfinishedResposes.length) { + for (const response of unfinishedResposes) { + fetchThreadResponse({ variables: { responseId: response.id } }); + } + // Explanation will be triggered by polling process + return; + } - if (unfinishedRespose) { - fetchThreadResponse({ variables: { responseId: unfinishedRespose.id } }); + // If all responses are finished, we need to check if there is any explanation that is not started + const unfinishedExplanations = (thread?.responses || []).filter( + (response) => response.explain?.queryId === null, + ); + for (const response of unfinishedExplanations) { + startThreadResponseExplanation(response); } }, [thread]); @@ -122,16 +152,7 @@ export default function HomeThread() { if (isFinished) { threadResponseResult.stopPolling(); - const isSuccessBreakdown = threadResponse?.error === null; - const isExplainable = - threadResponse?.explain && - threadResponse?.explain?.error === null && - threadResponse?.explain.queryId === null; - if (isSuccessBreakdown && isExplainable) { - createThreadResponseExplain({ - variables: { where: { responseId: threadResponse.id } }, - }).then(() => threadResponseResult.startPolling(1000)); - } + startThreadResponseExplanation(threadResponse); } }, [isFinished, threadResponse]); From 6782f3b15927477b32ae91ae83384ad449feacfd Mon Sep 17 00:00:00 2001 From: Shimin Date: Tue, 23 Jul 2024 16:08:15 +0800 Subject: [PATCH 11/13] feat(wren-ui): SQL Highlight References (#548) * feat(wren-ui): add sql highlight component, add format sql in asking service * feat(wren-ui): add two-way binding hovered highlight * fix(wren-ui): modify debug function * fix(wren-ui): user selection issue * fix(wren-ui): omit stepIndex type on collapseContentProps --- .../apollo/server/services/askingService.ts | 7 +- wren-ui/src/components/editor/CodeBlock.tsx | 32 ++- .../pages/home/thread/CollapseContent.tsx | 34 ++- .../pages/home/thread/StepContent.tsx | 3 +- .../thread/feedback/ReferenceSideFloat.tsx | 99 +++++++- .../home/thread/feedback/SQLHighlight.tsx | 215 ++++++++++++++++++ .../pages/home/thread/feedback/index.tsx | 21 +- .../components/pages/home/thread/index.tsx | 1 + 8 files changed, 390 insertions(+), 22 deletions(-) create mode 100644 wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index fa040e4d5..2dd6d2e9b 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -630,7 +630,12 @@ export class AskingService implements IAskingService { const project = await this.projectService.getCurrentProject(); const deployment = await this.deployService.getLastDeployment(project.id); const manifest = deployment.manifest; - const sqls = threadResponse.detail?.steps?.map((step) => step.sql); + const sqls = threadResponse.detail?.steps?.map((step, index) => { + const isLastStep = index === threadResponse.detail.steps.length - 1; + return isLastStep + ? format(constructCteSql(threadResponse.detail.steps)) + : format(step.sql); + }); const analysis = await this.ibisAdaptor.analysisSqls(manifest, sqls); return addAutoIncrementId(analysis); } diff --git a/wren-ui/src/components/editor/CodeBlock.tsx b/wren-ui/src/components/editor/CodeBlock.tsx index 169fa0e2c..77bd588a4 100644 --- a/wren-ui/src/components/editor/CodeBlock.tsx +++ b/wren-ui/src/components/editor/CodeBlock.tsx @@ -1,4 +1,4 @@ -import { useEffect } from 'react'; +import React, { useEffect } from 'react'; import { Typography } from 'antd'; import styled from 'styled-components'; import '@/components/editor/AceEditor'; @@ -19,16 +19,19 @@ const Block = styled.div<{ inline?: boolean; maxHeight?: string }>` : `background: var(--gray-1); padding: 8px;`} .adm-code-wrap { + position: relative; + padding-bottom: 2px; ${(props) => (props.inline ? '' : 'overflow: auto;')} ${(props) => (props.maxHeight ? `max-height: ${props.maxHeight}px;` : ``)} } .adm-code-line { display: block; + height: 22px; &-number { user-select: none; display: inline-block; - min-width: 14px; + min-width: 17px; text-align: right; margin-right: 1em; color: var(--gray-6); @@ -55,6 +58,7 @@ interface Props { loading?: boolean; maxHeight?: string; showLineNumbers?: boolean; + highlightSlot?: React.ReactNode; } const addThemeStyleManually = (cssText) => { @@ -69,21 +73,37 @@ const addThemeStyleManually = (cssText) => { } }; -export default function CodeBlock(props: Props) { - const { code, copyable, maxHeight, inline, loading, showLineNumbers } = props; +export const getTokenizer = () => { const { ace } = window as any; const { Tokenizer } = ace.require('ace/tokenizer'); const { SqlHighlightRules } = ace.require(`ace/mode/sql_highlight_rules`); const rules = new SqlHighlightRules(); const tokenizer = new Tokenizer(rules.getRules()); + return (line) => { + return tokenizer.getLineTokens(line).tokens; + }; +}; + +export default function CodeBlock(props: Props) { + const { + code, + copyable, + maxHeight, + inline, + loading, + showLineNumbers, + highlightSlot, + } = props; useEffect(() => { + const { ace } = window as any; const { cssText } = ace.require('ace/theme/tomorrow'); addThemeStyleManually(cssText); }, []); + const tokenize = getTokenizer(); const lines = (code || '').split('\n').map((line, index) => { - const tokens = tokenizer.getLineTokens(line).tokens; + const tokens = tokenize(line); const children = tokens.map((token, index) => { const classNames = token.type.split('.').map((name) => `ace_${name}`); return ( @@ -92,7 +112,6 @@ export default function CodeBlock(props: Props) { ); }); - return ( {showLineNumbers && ( @@ -112,6 +131,7 @@ export default function CodeBlock(props: Props) {
{lines} + {highlightSlot} {copyable && {code}}
diff --git a/wren-ui/src/components/pages/home/thread/CollapseContent.tsx b/wren-ui/src/components/pages/home/thread/CollapseContent.tsx index 967ed8d6b..ef1991dc0 100644 --- a/wren-ui/src/components/pages/home/thread/CollapseContent.tsx +++ b/wren-ui/src/components/pages/home/thread/CollapseContent.tsx @@ -1,3 +1,4 @@ +import { useMemo, useState } from 'react'; import Image from 'next/image'; import dynamic from 'next/dynamic'; import { Button, Switch, Typography, Empty } from 'antd'; @@ -10,6 +11,8 @@ import PreviewData from '@/components/dataPreview/PreviewData'; import { PreviewDataMutationResult } from '@/apollo/client/graphql/home.generated'; import { DATA_SOURCE_OPTIONS } from '@/components/pages/setup/utils'; import { NativeSQLResult } from '@/hooks/useNativeSQL'; +import SQLHighlight from '@/components/pages/home/thread/feedback/SQLHighlight'; +import { useFeedbackContext } from './feedback'; const CodeBlock = dynamic(() => import('@/components/editor/CodeBlock'), { ssr: false, @@ -36,6 +39,7 @@ export interface Props { onCloseCollapse: () => void; onCopyFullSQL?: () => void; sql: string; + stepIndex: number; previewDataResult: PreviewDataMutationResult; attributes: { stepNumber: number; @@ -53,6 +57,7 @@ export default function CollapseContent(props: Props) { onCloseCollapse, onCopyFullSQL, sql, + stepIndex, previewDataResult, attributes, onChangeNativeSQL, @@ -61,11 +66,27 @@ export default function CollapseContent(props: Props) { const { hasNativeSQL, dataSourceType } = nativeSQLResult; const showNativeSQL = Boolean(attributes.isLastStep) && hasNativeSQL; + const [isNativeSQL, setIsNativeSQL] = useState(false); + const { references, sqlTargetReference, onHighlightToReferences } = + useFeedbackContext(); + const currentStepReferences = useMemo(() => { + return (references || []).filter((item) => item.stepIndex === stepIndex); + }, [references]); const sqls = nativeSQLResult.nativeSQLMode && nativeSQLResult.loading === false ? nativeSQLResult.data : sql; + const hasReferences = references && references.length > 0; + + const onSwitchChange = (checked: boolean) => { + setIsNativeSQL(checked); + onChangeNativeSQL(checked); + }; + + const onHighlightHover = (reference) => { + onHighlightToReferences && onHighlightToReferences(reference); + }; return ( <> @@ -95,7 +116,7 @@ export default function CollapseContent(props: Props) { unCheckedChildren={} className="mr-2" size="small" - onChange={onChangeNativeSQL} + onChange={onSwitchChange} loading={nativeSQLResult.loading} /> @@ -109,6 +130,17 @@ export default function CollapseContent(props: Props) { showLineNumbers maxHeight="300" loading={nativeSQLResult.loading} + highlightSlot={ + hasReferences && + !isNativeSQL && ( + + ) + } /> )} diff --git a/wren-ui/src/components/pages/home/thread/StepContent.tsx b/wren-ui/src/components/pages/home/thread/StepContent.tsx index 2eb724f49..6f38394ed 100644 --- a/wren-ui/src/components/pages/home/thread/StepContent.tsx +++ b/wren-ui/src/components/pages/home/thread/StepContent.tsx @@ -106,8 +106,9 @@ export default function StepContent(props: Props) { )} + stepIndex={stepIndex} key={`collapse-${stepNumber}`} attributes={{ stepNumber, isLastStep }} /> diff --git a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx index 6a3cc3a57..42a0fab95 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/ReferenceSideFloat.tsx @@ -1,6 +1,13 @@ import clsx from 'clsx'; import { groupBy } from 'lodash'; -import { useMemo, useState } from 'react'; +import { + useMemo, + useState, + forwardRef, + useImperativeHandle, + useRef, + useEffect, +} from 'react'; import styled from 'styled-components'; import { Tag, Typography, Button, Input, Alert } from 'antd'; import { @@ -22,6 +29,15 @@ import { ERROR_CODES } from '@/utils/errorHandler'; const StyledReferenceSideFloat = styled.div` position: relative; + .referenceSideFloat__summary { + cursor: pointer; + padding: 2px 0; + &:hover, + &.isActive { + background-color: rgba(250, 219, 20, 0.3); + } + } + .referenceSideFloat-title { position: absolute; top: -14px; @@ -54,6 +70,7 @@ interface Props { error?: Record; onSaveCorrectionPrompt: (id: string, value: string) => void; onTriggerExplanation: () => void; + onHoverReference?: (reference: Reference) => void; } const ReferenceSummaryTemplate = ({ @@ -82,6 +99,7 @@ const GroupReferenceTemplate = ({ data, index, saveCorrectionPrompt, + hoverReference, }) => { if (!data.length) return null; return ( @@ -93,19 +111,19 @@ const GroupReferenceTemplate = ({
); }; const ReferenceTemplate = ({ - type, - summary, - referenceId, - referenceNum, - correctionPrompt, saveCorrectionPrompt, + hoverReference, + ...reference }) => { + const { type, summary, referenceId, referenceNum, correctionPrompt } = + reference; const [isEdit, setIsEdit] = useState(false); const [value, setValue] = useState(correctionPrompt); const isRevise = !!correctionPrompt; @@ -130,7 +148,16 @@ const ReferenceTemplate = ({
- {summary} + hoverReference(reference)} + onMouseLeave={() => hoverReference()} + > + {summary} + {isRevise ? ( '(feedback suggested)' @@ -170,8 +197,14 @@ const ReferenceSummaryIterator = makeIterable(ReferenceSummaryTemplate); const GroupReferenceIterator = makeIterable(GroupReferenceTemplate); const ReferenceIterator = makeIterable(ReferenceTemplate); -const References = (props: Props) => { - const { references, onSaveCorrectionPrompt } = props; +const References = (props: Props & { targetReference?: Reference }) => { + const { + references, + onSaveCorrectionPrompt, + targetReference, + onHoverReference, + } = props; + const $scroller = useRef(null); const referencesByGroup = groupBy(references, 'type'); const resources = REFERENCE_ORDERS.map((type) => ({ type, @@ -179,22 +212,58 @@ const References = (props: Props) => { data: referencesByGroup[type] || [], })); + useEffect(() => { + if ($scroller.current) { + const $element = $scroller.current; + const $targets = $element.querySelectorAll(`.isActive`); + $targets.forEach((target) => { + target.classList.remove('isActive'); + }); + + if (targetReference) { + const $target = $element.querySelector( + `.reference-${targetReference.referenceNum}`, + ); + $target.classList.add('isActive'); + $element.scrollTo({ top: $target.offsetTop - 100 }); + } + } + }, [targetReference]); + + const hoverReference = (reference?: Reference) => { + onHoverReference && onHoverReference(reference); + }; + return (
); }; -export default function ReferenceSideFloat(props: Props) { - const { references, error, onTriggerExplanation } = props; +function ReferenceSideFloat(props: Props, ref) { + const { references, error, onTriggerExplanation, onHoverReference } = props; const [collapse, setCollapse] = useState(false); + const [targetReference, setTargetReference] = useState( + null, + ); + + const triggerHighlight = (reference: Reference) => { + setCollapse(true); + setTargetReference(reference); + }; + + useImperativeHandle(ref, () => ({ + triggerHighlight, + })); const referencesSummary = useMemo( () => @@ -245,7 +314,11 @@ export default function ReferenceSideFloat(props: Props) { References
{collapse ? ( - + ) : ( <> @@ -262,3 +335,5 @@ export default function ReferenceSideFloat(props: Props) { ); } + +export default forwardRef(ReferenceSideFloat); diff --git a/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx b/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx new file mode 100644 index 000000000..6bf32e5ce --- /dev/null +++ b/wren-ui/src/components/pages/home/thread/feedback/SQLHighlight.tsx @@ -0,0 +1,215 @@ +import { useEffect, useMemo, useRef } from 'react'; +import clsx from 'clsx'; +import { Tag } from 'antd'; +import { groupBy } from 'lodash'; +import styled from 'styled-components'; +import { getReferenceIcon, Reference } from './utils'; +import { getTokenizer } from '@/components/editor/CodeBlock'; + +const SQLWrapper = styled.div` + position: absolute; + top: 0; + left: 28px; + right: 0; + z-index: 1; + font-size: 14px; + color: var(--gray-9); + margin: 0 3px; + + .sqlHighlight__line { + background-color: white; + height: 22px; + } + + .sqlHighlight__block { + position: relative; + &:hover, + &.isActive { + mark { + background-color: rgba(250, 219, 20, 0.3); + } + } + } + + mark { + cursor: pointer; + position: relative; + color: currentColor; + background-color: transparent; + border-bottom: 1px dashed var(--gray-5); + padding: 2px 0; + } + + .sqlHighlight__tags { + user-select: none; + padding: 0 4px; + + &:after { + content: ''; + vertical-align: middle; + } + + .ant-tag { + cursor: pointer; + margin-right: 0; + vertical-align: middle; + + .ant-tag { + margin-left: 4px; + } + } + } +`; + +interface Props { + sql: string; + references: Reference[]; + targetReference?: Reference; + onHighlightHover?: (reference: Reference) => void; +} + +const optimizedSnippet = (snippet: string) => { + // SQL analysis may add more spaces and add brackets to the sql, so we need to handle it. + return snippet + .replace(/\(/g, '\\(?') + .replace(/\)/g, '\\)?') + .replace(/\s/g, '\\s*'); +}; + +const createSnippetsRegex = (snippets: string[]) => { + return new RegExp(`(${snippets.join('|')})`, 'gi'); +}; + +const _printUnmatchedReferences = ( + references: Reference[], + referenceMatches, +) => { + // For debugging purpose + const matchesReferences = referenceMatches.flat(); + const unmatchedReferences = references.filter( + (reference) => + !matchesReferences.find((r) => r.referenceNum === reference.referenceNum), + ); + if (unmatchedReferences.length > 0) + console.log('Unmatched references:', unmatchedReferences); +}; + +export default function SQLHighlight(props: Props) { + const { sql, references, targetReference, onHighlightHover } = props; + const $wrapper = useRef(null); + + useEffect(() => { + if ($wrapper.current) { + const $element = $wrapper.current; + const $targets = $element.querySelectorAll(`.isActive`); + $targets.forEach((target) => { + target.classList.remove('isActive'); + }); + if (targetReference) { + const $target = $wrapper.current.querySelector( + `.reference-${targetReference.referenceNum}`, + ); + if (!$target) return; + $target.classList.add('isActive'); + } + } + }, [targetReference]); + + const sqlArray = useMemo(() => sql.split('\n'), [sql]); + const referenceGroups = useMemo(() => { + const filteredReferences = references + .filter((reference) => reference.sqlLocation) + .map((reference) => ({ + ...reference, + sqlSnippet: reference.sqlSnippet + ? optimizedSnippet(reference.sqlSnippet) + : reference.sqlSnippet, + })); + return groupBy( + filteredReferences, + (reference) => reference.sqlLocation.line, + ); + }, [references]); + + const hoverHighlight = (reference?: Reference) => { + onHighlightHover && onHighlightHover(reference); + }; + const highlights = []; + const referenceMatches = []; + const tokenize = getTokenizer(); + Object.keys(referenceGroups).forEach((line) => { + const lineIndex = Number(line) - 1; + const lineReferences = referenceGroups[line]; + const snippets = lineReferences.map((r) => r.sqlSnippet); + const regex = createSnippetsRegex(snippets); + const parts = sqlArray[lineIndex].split(regex); + + // Add to highlights if the part is matched + highlights[lineIndex] = parts.map((part, index) => { + const tokens = tokenize(part); + const tokenizedPart = tokens.map((token, tokenIndex) => { + const classNames = token.type.split('.').map((name) => `ace_${name}`); + return ( + + {token.value} + + ); + }); + if (regex.test(part)) { + const matchedReferences = lineReferences.filter((reference) => + new RegExp(reference.sqlSnippet, 'i').test(part), + ); + const tags = matchedReferences.map((reference) => { + return ( + + + {getReferenceIcon(reference.type)} + + {reference.referenceNum} + + ); + }); + // Record the matched references + referenceMatches.push(matchedReferences); + const reference = matchedReferences[0]; + return ( + hoverHighlight(reference)} + onMouseLeave={() => hoverHighlight()} + key={index} + > + {tokenizedPart} + {tags && {tags}} + + ); + } + return {tokenizedPart}; + }); + }); + + const content = sqlArray.map((line, index) => { + if (highlights[index]) { + return ( +
+ {highlights[index]} +
+ ); + } + return ( +
+ {line} +
+ ); + }); + + // For debugging purpose + // _printUnmatchedReferences(references, referenceMatches); + + return {content}; +} diff --git a/wren-ui/src/components/pages/home/thread/feedback/index.tsx b/wren-ui/src/components/pages/home/thread/feedback/index.tsx index 04f1bbd85..814f9a40f 100644 --- a/wren-ui/src/components/pages/home/thread/feedback/index.tsx +++ b/wren-ui/src/components/pages/home/thread/feedback/index.tsx @@ -1,4 +1,4 @@ -import { createContext, useContext, useMemo, useState } from 'react'; +import { createContext, useContext, useMemo, useRef, useState } from 'react'; import { sortBy } from 'lodash'; import { Skeleton } from 'antd'; import ReferenceSideFloat from '@/components/pages/home/thread/feedback/ReferenceSideFloat'; @@ -11,10 +11,14 @@ import { getIsExplainFinished } from '@/hooks/useAskPrompt'; type ContextProps = { references: Reference[]; + sqlTargetReference?: Reference; + onHighlightToReferences: (target?: Reference) => void; } | null; export const FeedbackContext = createContext({ references: [], + sqlTargetReference: null, + onHighlightToReferences: () => {}, }); export const useFeedbackContext = () => { @@ -38,6 +42,9 @@ export default function Feedback(props: Props) { onTriggerThreadResponseExplain, } = props; + const reviewSideFloat = useRef(null); + const [sqlTargetReference, setSqlTargetReference] = + useState(); const [correctionPrompts, setCorrectionPrompts] = useState({}); const reviewDrawer = useDrawerAction(); @@ -57,6 +64,10 @@ export default function Feedback(props: Props) { onTriggerThreadResponseExplain({ responseId: threadResponse.id }); }; + const hoverReference = (reference?: Reference) => { + setSqlTargetReference(reference); + }; + const loading = useMemo( () => !getIsExplainFinished(threadResponse?.explain?.status), [threadResponse?.explain?.status], @@ -85,6 +96,12 @@ export default function Feedback(props: Props) { const contextValue = { references, + sqlTargetReference, + onHighlightToReferences: (target) => { + if (reviewSideFloat.current) { + reviewSideFloat.current?.triggerHighlight(target); + } + }, }; return ( @@ -105,10 +122,12 @@ export default function Feedback(props: Props) {
diff --git a/wren-ui/src/components/pages/home/thread/index.tsx b/wren-ui/src/components/pages/home/thread/index.tsx index 2d5a2a179..c0c681084 100644 --- a/wren-ui/src/components/pages/home/thread/index.tsx +++ b/wren-ui/src/components/pages/home/thread/index.tsx @@ -107,6 +107,7 @@ export default function Thread(props: Props) { record.id} onOpenSaveAsViewModal={onOpenSaveAsViewModal} onTriggerScrollToBottom={triggerScrollToBottom} onSubmitReviewDrawer={onSubmitReviewDrawer} From 62d714884ecece539cb8d5593979f560b480cd75 Mon Sep 17 00:00:00 2001 From: onlyjackfrost Date: Wed, 24 Jul 2024 16:08:16 +0800 Subject: [PATCH 12/13] - add cte_name when asking ai service to explain sql --- wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts | 5 +++++ wren-ui/src/apollo/server/services/askingService.ts | 1 + 2 files changed, 6 insertions(+) diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index fd43d2002..0a9c7535a 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -76,6 +76,7 @@ export enum AskResultStatus { export interface StepAnalysisResult { sql: string; summary: string; + cte_name?: string; sql_analysis_results: any; } @@ -254,6 +255,10 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { stepAnalysisResult: StepAnalysisResult[], ): Promise { try { + logger.info({ + question, + steps_with_analysis_results: stepAnalysisResult, + }); const res = await axios.post( `${this.wrenAIBaseEndpoint}/v1/sql-explanations`, { diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 2dd6d2e9b..39fc723b6 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -501,6 +501,7 @@ export class AskingService implements IAskingService { return { sql: step.sql, summary: step.summary, + cte_name: step.cteName, sql_analysis_results: analysisWithIds[idx], } as StepAnalysisResult; }, From 9468b131b038438f0f9fedbe545830d7c02a01ac Mon Sep 17 00:00:00 2001 From: onlyjackfrost Date: Thu, 25 Jul 2024 10:20:51 +0800 Subject: [PATCH 13/13] use different nodeLocation with analysis type is relationship --- .../server/backgroundTrackers/explainBackgroundTracker.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts b/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts index c2eaa7ca2..a37c38c44 100644 --- a/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts +++ b/wren-ui/src/apollo/server/backgroundTrackers/explainBackgroundTracker.ts @@ -17,6 +17,7 @@ import { GeneralErrorCodes } from '../utils/error'; import { findAnalysisById, reverseEnum } from '../utils'; import { BackgroundTracker } from './index'; import { getLogger } from '@server/utils/logger'; +import { RelationType } from '../adaptors/ibisAdaptor'; const logger = getLogger('ExplainBackgroundTracker'); logger.level = 'debug'; @@ -211,6 +212,9 @@ export class ThreadResponseExplainBackgroundTracker extends BackgroundTracker