From ad3b9880c792833e7590a60d57b65e08ecbd9b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Wed, 8 Jan 2025 19:59:20 +0100 Subject: [PATCH] [Obs AI Assistant] Use architecture-specific elser model (#205851) Closes https://github.com/elastic/kibana/issues/205852 When installing the Obs knowledge base it will always install the model `.elser_model_2`. For Linux with an x86-64 CPU an optimised version of Elser exists (`elser_model_2_linux-x86_64`). We should use that when possible. After this change the inference endpoint will use `.elser_model_2_linux-x86_64` on supported hardware: ![image](https://github.com/user-attachments/assets/fedc6700-877a-47ab-a3b8-055db53407d0) --- .../server/service/client/index.ts | 5 ++ .../server/service/inference_endpoint.ts | 4 +- .../get_elser_model_id.ts | 53 +++++++++++++++++++ .../service/knowledge_base_service/index.ts | 2 +- .../recall_from_search_connectors.ts | 43 +-------------- 5 files changed, 63 insertions(+), 44 deletions(-) create mode 100644 x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/get_elser_model_id.ts diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/client/index.ts index 8d1ee6138e54f..95856768c4c5d 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -80,6 +80,7 @@ import { } from '../task_manager_definitions/register_migrate_knowledge_base_entries_task'; import { ObservabilityAIAssistantPluginStartDependencies } from '../../types'; import { ObservabilityAIAssistantConfig } from '../../config'; +import { getElserModelId } from '../knowledge_base_service/get_elser_model_id'; const MAX_FUNCTION_CALLS = 8; @@ -660,6 +661,10 @@ export class ObservabilityAIAssistantClient { setupKnowledgeBase = async (modelId: string | undefined) => { const { esClient, core, logger, knowledgeBaseService } = this.dependencies; + if (!modelId) { + modelId = await getElserModelId({ core, logger }); + } + // setup the knowledge base const res = await knowledgeBaseService.setup(esClient, modelId); diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/inference_endpoint.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/inference_endpoint.ts index a2993f7353c61..1822b7766b0b7 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/inference_endpoint.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/inference_endpoint.ts @@ -16,13 +16,13 @@ export const AI_ASSISTANT_KB_INFERENCE_ID = 'obs_ai_assistant_kb_inference'; export async function createInferenceEndpoint({ esClient, logger, - modelId = '.elser_model_2', + modelId, }: { esClient: { asCurrentUser: ElasticsearchClient; }; logger: Logger; - modelId: string | undefined; + modelId: string; }) { try { logger.debug(`Creating inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}"`); diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/get_elser_model_id.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/get_elser_model_id.ts new file mode 100644 index 0000000000000..99f4ceb6c247f --- /dev/null +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/get_elser_model_id.ts @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { Logger } from '@kbn/logging'; +import { CoreSetup } from '@kbn/core-lifecycle-server'; +import { firstValueFrom } from 'rxjs'; +import { ObservabilityAIAssistantPluginStartDependencies } from '../../types'; + +export async function getElserModelId({ + core, + logger, +}: { + core: CoreSetup; + logger: Logger; +}) { + const defaultModelId = '.elser_model_2'; + const [_, pluginsStart] = await core.getStartServices(); + + // Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats + const license = await firstValueFrom(pluginsStart.licensing.license$); + if (!license.hasAtLeast('enterprise')) { + return defaultModelId; + } + + try { + // Wait for the ML plugin's dependency on the internal saved objects client to be ready + const { ml } = await core.plugins.onSetup<{ + ml: { + trainedModelsProvider: ( + request: {}, + soClient: {} + ) => { getELSER: () => Promise<{ model_id: string }> }; + }; + }>('ml'); + + if (!ml.found) { + throw new Error('Could not find ML plugin'); + } + + const elserModelDefinition = await ml.contract + .trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user) + .getELSER(); + + return elserModelDefinition.model_id; + } catch (error) { + logger.error(`Failed to resolve ELSER model definition: ${error}`); + return defaultModelId; + } +} diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts index bb77dfc768d95..cba83f715ff61 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts @@ -58,7 +58,7 @@ export class KnowledgeBaseService { asCurrentUser: ElasticsearchClient; asInternalUser: ElasticsearchClient; }, - modelId: string | undefined + modelId: string ) { await deleteInferenceEndpoint({ esClient }).catch((e) => {}); // ensure existing inference endpoint is deleted return createInferenceEndpoint({ esClient, logger: this.dependencies.logger, modelId }); diff --git a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/recall_from_search_connectors.ts b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/recall_from_search_connectors.ts index b8a0a7d9267bc..3001a28f6dbbb 100644 --- a/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/recall_from_search_connectors.ts +++ b/x-pack/platform/plugins/shared/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/recall_from_search_connectors.ts @@ -10,10 +10,10 @@ import { IUiSettingsClient } from '@kbn/core-ui-settings-server'; import { isEmpty, orderBy, compact } from 'lodash'; import type { Logger } from '@kbn/logging'; import { CoreSetup } from '@kbn/core-lifecycle-server'; -import { firstValueFrom } from 'rxjs'; import { RecalledEntry } from '.'; import { aiAssistantSearchConnectorIndexPattern } from '../../../common'; import { ObservabilityAIAssistantPluginStartDependencies } from '../../types'; +import { getElserModelId } from './get_elser_model_id'; export async function recallFromSearchConnectors({ queries, @@ -128,7 +128,7 @@ async function recallFromLegacyConnectors({ }): Promise { const ML_INFERENCE_PREFIX = 'ml.inference.'; - const modelIdPromise = getElserModelId(core, logger); // pre-fetch modelId in parallel with fieldCaps + const modelIdPromise = getElserModelId({ core, logger }); // pre-fetch modelId in parallel with fieldCaps const fieldCaps = await esClient.asCurrentUser.fieldCaps({ index: connectorIndices, fields: `${ML_INFERENCE_PREFIX}*`, @@ -230,42 +230,3 @@ async function getConnectorIndices( return connectorIndices; } - -async function getElserModelId( - core: CoreSetup, - logger: Logger -) { - const defaultModelId = '.elser_model_2'; - const [_, pluginsStart] = await core.getStartServices(); - - // Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats - const license = await firstValueFrom(pluginsStart.licensing.license$); - if (!license.hasAtLeast('enterprise')) { - return defaultModelId; - } - - try { - // Wait for the ML plugin's dependency on the internal saved objects client to be ready - const { ml } = await core.plugins.onSetup('ml'); - - if (!ml.found) { - throw new Error('Could not find ML plugin'); - } - - const elserModelDefinition = await ( - ml.contract as { - trainedModelsProvider: ( - request: {}, - soClient: {} - ) => { getELSER: () => Promise<{ model_id: string }> }; - } - ) - .trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user) - .getELSER(); - - return elserModelDefinition.model_id; - } catch (error) { - logger.error(`Failed to resolve ELSER model definition: ${error}`); - return defaultModelId; - } -}