Skip to content

Commit

Permalink
[7.x] [ML] Add decision path charts to exploration results table (ela…
Browse files Browse the repository at this point in the history
…stic#73561) (elastic#77082)

Co-authored-by: Elastic Machine <[email protected]>

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
qn895 and elasticmachine authored Sep 9, 2020
1 parent 2b765b1 commit ac5a6ea
Show file tree
Hide file tree
Showing 27 changed files with 1,083 additions and 125 deletions.
7 changes: 7 additions & 0 deletions x-pack/plugins/ml/common/constants/data_frame_analytics.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

export const DEFAULT_RESULTS_FIELD = 'ml';
6 changes: 6 additions & 0 deletions x-pack/plugins/ml/common/types/data_frame_analytics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ export interface DataFrameAnalyticsConfig {
version: string;
allow_lazy_start?: boolean;
}

export enum ANALYSIS_CONFIG_TYPE {
OUTLIER_DETECTION = 'outlier_detection',
REGRESSION = 'regression',
CLASSIFICATION = 'classification',
}
23 changes: 23 additions & 0 deletions x-pack/plugins/ml/common/types/feature_importance.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

export interface ClassFeatureImportance {
class_name: string | boolean;
importance: number;
}
export interface FeatureImportance {
feature_name: string;
importance?: number;
classes?: ClassFeatureImportance[];
}

export interface TopClass {
class_name: string;
class_probability: number;
class_score: number;
}

export type TopClasses = TopClass[];
79 changes: 79 additions & 0 deletions x-pack/plugins/ml/common/util/analytics_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

import {
AnalysisConfig,
ClassificationAnalysis,
OutlierAnalysis,
RegressionAnalysis,
ANALYSIS_CONFIG_TYPE,
} from '../types/data_frame_analytics';

export const isOutlierAnalysis = (arg: any): arg is OutlierAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION;
};

export const isRegressionAnalysis = (arg: any): arg is RegressionAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.REGRESSION;
};

export const isClassificationAnalysis = (arg: any): arg is ClassificationAnalysis => {
const keys = Object.keys(arg);
return keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION;
};

export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
depVar = analysis.regression.dependent_variable;
}

if (isClassificationAnalysis(analysis)) {
depVar = analysis.classification.dependent_variable;
}
return depVar;
};

export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
predictionFieldName = analysis.regression.prediction_field_name;
} else if (
isClassificationAnalysis(analysis) &&
analysis.classification.prediction_field_name !== undefined
) {
predictionFieldName = analysis.classification.prediction_field_name;
}
return predictionFieldName;
};

export const getDefaultPredictionFieldName = (analysis: AnalysisConfig) => {
return `${getDependentVar(analysis)}_prediction`;
};
export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
forSort?: boolean
) => {
// default is 'ml'
const predictionFieldName = getPredictionFieldName(analysis);
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : getDefaultPredictionFieldName(analysis)
}`;
return predictedField;
};
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ export const getDataGridSchemasFromFieldTypes = (fieldTypes: FieldTypes, results
schema = 'numeric';
}

if (
field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`) ||
field.includes(`${resultsField}.${TOP_CLASSES}`)
) {
if (field.includes(`${resultsField}.${TOP_CLASSES}`)) {
schema = 'json';
}

if (field.includes(`${resultsField}.${FEATURE_IMPORTANCE}`)) {
schema = 'featureImportance';
}

return { id: field, schema, isSortable };
});
};
Expand Down Expand Up @@ -250,10 +251,6 @@ export const useRenderCellValue = (
return cellValue ? 'true' : 'false';
}

if (typeof cellValue === 'object' && cellValue !== null) {
return JSON.stringify(cellValue);
}

return cellValue;
};
}, [indexPattern?.fields, pagination.pageIndex, pagination.pageSize, tableItems]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
*/

import { isEqual } from 'lodash';
import React, { memo, useEffect, FC } from 'react';

import React, { memo, useEffect, FC, useMemo } from 'react';
import { i18n } from '@kbn/i18n';

import {
Expand All @@ -24,13 +23,16 @@ import {
} from '@elastic/eui';

import { CoreSetup } from 'src/core/public';

import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_histograms';

import { INDEX_STATUS } from '../../data_frame_analytics/common';
import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';

import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
import { UseIndexDataReturnType } from './types';
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
import { TopClasses } from '../../../../common/types/feature_importance';
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';

// TODO Fix row hovering + bar highlighting
// import { hoveredRow$ } from './column_chart';

Expand All @@ -41,6 +43,9 @@ export const DataGridTitle: FC<{ title: string }> = ({ title }) => (
);

interface PropsWithoutHeader extends UseIndexDataReturnType {
baseline?: number;
analysisType?: ANALYSIS_CONFIG_TYPE;
resultsField?: string;
dataTestSubj: string;
toastNotifications: CoreSetup['notifications']['toasts'];
}
Expand All @@ -60,6 +65,7 @@ type Props = PropsWithHeader | PropsWithoutHeader;
export const DataGrid: FC<Props> = memo(
(props) => {
const {
baseline,
chartsVisible,
chartsButtonVisible,
columnsWithCharts,
Expand All @@ -80,8 +86,10 @@ export const DataGrid: FC<Props> = memo(
toastNotifications,
toggleChartVisibility,
visibleColumns,
predictionFieldName,
resultsField,
analysisType,
} = props;

// TODO Fix row hovering + bar highlighting
// const getRowProps = (item: any) => {
// return {
Expand All @@ -90,6 +98,45 @@ export const DataGrid: FC<Props> = memo(
// };
// };

const popOverContent = useMemo(() => {
return analysisType === ANALYSIS_CONFIG_TYPE.REGRESSION ||
analysisType === ANALYSIS_CONFIG_TYPE.CLASSIFICATION
? {
featureImportance: ({ children }: { cellContentsElement: any; children: any }) => {
const rowIndex = children?.props?.visibleRowIndex;
const row = data[rowIndex];
if (!row) return <div />;
// if resultsField for some reason is not available then use ml
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
const parsedFIArray = row[mlResultsField].feature_importance;
let predictedValue: string | number | undefined;
let topClasses: TopClasses = [];
if (
predictionFieldName !== undefined &&
row &&
row[mlResultsField][predictionFieldName] !== undefined
) {
predictedValue = row[mlResultsField][predictionFieldName];
topClasses = row[mlResultsField].top_classes;
}

return (
<DecisionPathPopover
analysisType={analysisType}
predictedValue={predictedValue}
baseline={baseline}
featureImportance={parsedFIArray}
topClasses={topClasses}
predictionFieldName={
predictionFieldName ? predictionFieldName.replace('_prediction', '') : undefined
}
/>
);
},
}
: undefined;
}, [baseline, data]);

useEffect(() => {
if (invalidSortingColumnns.length > 0) {
invalidSortingColumnns.forEach((columnId) => {
Expand Down Expand Up @@ -225,6 +272,7 @@ export const DataGrid: FC<Props> = memo(
}
: {}),
}}
popoverContents={popOverContent}
pagination={{
...pagination,
pageSizeOptions: [5, 10, 25],
Expand Down
Loading

0 comments on commit ac5a6ea

Please sign in to comment.