Skip to content

Commit

Permalink
[Obs AI Assistant] Use architecture-specific elser model (elastic#205851
Browse files Browse the repository at this point in the history
)

Closes elastic#205852

When installing the Obs knowledge base it will always install the model
`.elser_model_2`.
For Linux with an x86-64 CPU an optimised version of Elser exists
(`elser_model_2_linux-x86_64`). We should use that when possible.

After this change the inference endpoint will use
`.elser_model_2_linux-x86_64` on supported hardware:

![image](https://github.com/user-attachments/assets/fedc6700-877a-47ab-a3b8-055db53407d0)
  • Loading branch information
sorenlouv authored and viduni94 committed Jan 23, 2025
1 parent 60e24f1 commit 71f2831
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ import {
} from '../task_manager_definitions/register_migrate_knowledge_base_entries_task';
import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
import { ObservabilityAIAssistantConfig } from '../../config';
import { getElserModelId } from '../knowledge_base_service/get_elser_model_id';

const MAX_FUNCTION_CALLS = 8;

Expand Down Expand Up @@ -660,6 +661,10 @@ export class ObservabilityAIAssistantClient {
setupKnowledgeBase = async (modelId: string | undefined) => {
const { esClient, core, logger, knowledgeBaseService } = this.dependencies;

if (!modelId) {
modelId = await getElserModelId({ core, logger });
}

// setup the knowledge base
const res = await knowledgeBaseService.setup(esClient, modelId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ export const AI_ASSISTANT_KB_INFERENCE_ID = 'obs_ai_assistant_kb_inference';
export async function createInferenceEndpoint({
esClient,
logger,
modelId = '.elser_model_2',
modelId,
}: {
esClient: {
asCurrentUser: ElasticsearchClient;
};
logger: Logger;
modelId: string | undefined;
modelId: string;
}) {
try {
logger.debug(`Creating inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}"`);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { Logger } from '@kbn/logging';
import { CoreSetup } from '@kbn/core-lifecycle-server';
import { firstValueFrom } from 'rxjs';
import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';

export async function getElserModelId({
core,
logger,
}: {
core: CoreSetup<ObservabilityAIAssistantPluginStartDependencies>;
logger: Logger;
}) {
const defaultModelId = '.elser_model_2';
const [_, pluginsStart] = await core.getStartServices();

// Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats
const license = await firstValueFrom(pluginsStart.licensing.license$);
if (!license.hasAtLeast('enterprise')) {
return defaultModelId;
}

try {
// Wait for the ML plugin's dependency on the internal saved objects client to be ready
const { ml } = await core.plugins.onSetup<{
ml: {
trainedModelsProvider: (
request: {},
soClient: {}
) => { getELSER: () => Promise<{ model_id: string }> };
};
}>('ml');

if (!ml.found) {
throw new Error('Could not find ML plugin');
}

const elserModelDefinition = await ml.contract
.trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user)
.getELSER();

return elserModelDefinition.model_id;
} catch (error) {
logger.error(`Failed to resolve ELSER model definition: ${error}`);
return defaultModelId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export class KnowledgeBaseService {
asCurrentUser: ElasticsearchClient;
asInternalUser: ElasticsearchClient;
},
modelId: string | undefined
modelId: string
) {
await deleteInferenceEndpoint({ esClient }).catch((e) => {}); // ensure existing inference endpoint is deleted
return createInferenceEndpoint({ esClient, logger: this.dependencies.logger, modelId });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import { IUiSettingsClient } from '@kbn/core-ui-settings-server';
import { isEmpty, orderBy, compact } from 'lodash';
import type { Logger } from '@kbn/logging';
import { CoreSetup } from '@kbn/core-lifecycle-server';
import { firstValueFrom } from 'rxjs';
import { RecalledEntry } from '.';
import { aiAssistantSearchConnectorIndexPattern } from '../../../common';
import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
import { getElserModelId } from './get_elser_model_id';

export async function recallFromSearchConnectors({
queries,
Expand Down Expand Up @@ -128,7 +128,7 @@ async function recallFromLegacyConnectors({
}): Promise<RecalledEntry[]> {
const ML_INFERENCE_PREFIX = 'ml.inference.';

const modelIdPromise = getElserModelId(core, logger); // pre-fetch modelId in parallel with fieldCaps
const modelIdPromise = getElserModelId({ core, logger }); // pre-fetch modelId in parallel with fieldCaps
const fieldCaps = await esClient.asCurrentUser.fieldCaps({
index: connectorIndices,
fields: `${ML_INFERENCE_PREFIX}*`,
Expand Down Expand Up @@ -230,42 +230,3 @@ async function getConnectorIndices(

return connectorIndices;
}

async function getElserModelId(
core: CoreSetup<ObservabilityAIAssistantPluginStartDependencies>,
logger: Logger
) {
const defaultModelId = '.elser_model_2';
const [_, pluginsStart] = await core.getStartServices();

// Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats
const license = await firstValueFrom(pluginsStart.licensing.license$);
if (!license.hasAtLeast('enterprise')) {
return defaultModelId;
}

try {
// Wait for the ML plugin's dependency on the internal saved objects client to be ready
const { ml } = await core.plugins.onSetup('ml');

if (!ml.found) {
throw new Error('Could not find ML plugin');
}

const elserModelDefinition = await (
ml.contract as {
trainedModelsProvider: (
request: {},
soClient: {}
) => { getELSER: () => Promise<{ model_id: string }> };
}
)
.trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user)
.getELSER();

return elserModelDefinition.model_id;
} catch (error) {
logger.error(`Failed to resolve ELSER model definition: ${error}`);
return defaultModelId;
}
}

0 comments on commit 71f2831

Please sign in to comment.