From fa1c296ef8db8df9d3294ffa1df2ae3e6582f5d3 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 20 Mar 2024 17:17:02 +0530 Subject: [PATCH 1/2] refactor and fix scoped context chat - added routes: /providers, /default-provider-key - merge ScopedContextChatProvider to ContextChatProvider - make use of ProviderService everywhere, which extends ProviderConfigService - 'file' provider type is now files__default - handle mimetype array parsing error in `context_chat:scan` command Signed-off-by: Anupam Kumar --- appinfo/routes.php | 3 + lib/AppInfo/Application.php | 13 +- lib/BackgroundJobs/IndexerJob.php | 11 +- .../InitialContentImportJob.php | 10 +- lib/BackgroundJobs/SubmitContentJob.php | 6 +- lib/Command/Prompt.php | 45 ++--- lib/Command/ScanFiles.php | 5 + lib/Controller/ProviderController.php | 58 +++++++ lib/Listener/AppDisableListener.php | 8 +- lib/Listener/FileListener.php | 7 +- lib/Public/ContentManager.php | 14 +- lib/Service/LangRopeService.php | 49 +++--- lib/Service/ProviderConfigService.php | 13 +- lib/Service/ProviderService.php | 91 ++++++++++ lib/Service/ScanService.php | 43 +++-- lib/TextProcessing/ContextChatProvider.php | 158 +++++++++++++++++- .../ScopedContextChatProvider.php | 93 ----------- .../ScopedContextChatTaskType.php | 52 ------ tests/integration/ContentManagerTest.php | 39 ++--- .../integration/ProviderConfigServiceTest.php | 15 +- 20 files changed, 469 insertions(+), 264 deletions(-) create mode 100644 lib/Controller/ProviderController.php create mode 100644 lib/Service/ProviderService.php delete mode 100644 lib/TextProcessing/ScopedContextChatProvider.php delete mode 100644 lib/TextProcessing/ScopedContextChatTaskType.php diff --git a/appinfo/routes.php b/appinfo/routes.php index c5319f4..9bca5a8 100644 --- a/appinfo/routes.php +++ b/appinfo/routes.php @@ -13,5 +13,8 @@ 'routes' => [ ['name' => 'config#setConfig', 'url' => '/config', 'verb' => 'PUT'], ['name' => 'config#setAdminConfig', 'url' => '/admin-config', 'verb' => 
'PUT'], + + ['name' => 'provider#getProviders', 'url' => '/providers', 'verb' => 'GET'], + ['name' => 'provider#getDefaultProviderKey', 'url' => '/default-provider-key', 'verb' => 'GET'], ], ]; diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index b547e86..f786e21 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -11,9 +11,9 @@ use OCA\ContextChat\Listener\AppDisableListener; use OCA\ContextChat\Listener\FileListener; +use OCA\ContextChat\Service\ProviderConfigService; use OCA\ContextChat\TextProcessing\ContextChatProvider; use OCA\ContextChat\TextProcessing\FreePromptProvider; -use OCA\ContextChat\TextProcessing\ScopedContextChatProvider; use OCP\App\Events\AppDisableEvent; use OCP\AppFramework\App; use OCP\AppFramework\Bootstrap\IBootContext; @@ -24,6 +24,7 @@ use OCP\Files\Events\Node\NodeCreatedEvent; use OCP\Files\Events\Node\NodeWrittenEvent; use OCP\Files\Events\NodeRemovedFromCache; +use OCP\IConfig; use OCP\Share\Events\ShareCreatedEvent; use OCP\Share\Events\ShareDeletedEvent; @@ -58,8 +59,13 @@ class Application extends App implements IBootstrap { 'text/org', ]; + private IConfig $config; + public function __construct(array $urlParams = []) { parent::__construct(self::APP_ID, $urlParams); + + $container = $this->getContainer(); + $this->config = $container->get(IConfig::class); } public function register(IRegistrationContext $context): void { @@ -73,7 +79,10 @@ public function register(IRegistrationContext $context): void { $context->registerEventListener(AppDisableEvent::class, AppDisableListener::class); $context->registerTextProcessingProvider(ContextChatProvider::class); $context->registerTextProcessingProvider(FreePromptProvider::class); - $context->registerTextProcessingProvider(ScopedContextChatProvider::class); + + $providerConfigService = new ProviderConfigService($this->config); + /** @psalm-suppress ArgumentTypeCoercion, UndefinedClass */ + $providerConfigService->updateProvider('files', 
'default', '', true); } public function boot(IBootContext $context): void { diff --git a/lib/BackgroundJobs/IndexerJob.php b/lib/BackgroundJobs/IndexerJob.php index c3f0d0b..2b15e0a 100644 --- a/lib/BackgroundJobs/IndexerJob.php +++ b/lib/BackgroundJobs/IndexerJob.php @@ -9,6 +9,7 @@ use OCA\ContextChat\Db\QueueFile; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Service\QueueService; use OCA\ContextChat\Service\StorageService; use OCA\ContextChat\Type\Source; @@ -125,7 +126,15 @@ protected function index(array $files): void { $userIds = $this->storageService->getUsersForFileId($queueFile->getFileId()); foreach ($userIds as $userId) { try { - $source = new Source($userId, 'file: ' . $file->getId(), $file->getPath(), $fileHandle, $file->getMtime(), $file->getMimeType(), 'file'); + $source = new Source( + $userId, + ProviderService::getSourceId($file->getId()), + $file->getPath(), + $fileHandle, + $file->getMtime(), + $file->getMimeType(), + ProviderService::getDefaultProviderKey(), + ); } catch (InvalidPathException|NotFoundException $e) { $this->logger->error('Could not find file ' . 
$file->getPath(), ['exception' => $e]); continue 2; diff --git a/lib/BackgroundJobs/InitialContentImportJob.php b/lib/BackgroundJobs/InitialContentImportJob.php index 50b7f92..1d0ada6 100644 --- a/lib/BackgroundJobs/InitialContentImportJob.php +++ b/lib/BackgroundJobs/InitialContentImportJob.php @@ -14,7 +14,7 @@ namespace OCA\ContextChat\BackgroundJobs; use OCA\ContextChat\Public\IContentProvider; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\App\IAppManager; use OCP\AppFramework\Utility\ITimeFactory; use OCP\BackgroundJob\QueuedJob; @@ -27,7 +27,7 @@ class InitialContentImportJob extends QueuedJob { public function __construct( private IAppManager $appManager, - private ProviderConfigService $configService, + private ProviderService $providerService, private LoggerInterface $logger, private IUserManager $userMan, ITimeFactory $timeFactory, @@ -57,8 +57,8 @@ protected function run($argument): void { return; } - $registeredProviders = $this->configService->getProviders(); - $identifier = ProviderConfigService::getConfigKey($providerObj->getAppId(), $providerObj->getId()); + $registeredProviders = $this->providerService->getProviders(); + $identifier = ProviderService::getConfigKey($providerObj->getAppId(), $providerObj->getId()); if (!isset($registeredProviders[$identifier]) || $registeredProviders[$identifier]['isInitiated'] ) { @@ -66,6 +66,6 @@ protected function run($argument): void { } $providerObj->triggerInitialImport(); - $this->configService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $argument, true); + $this->providerService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $argument, true); } } diff --git a/lib/BackgroundJobs/SubmitContentJob.php b/lib/BackgroundJobs/SubmitContentJob.php index 8916199..cf2533e 100644 --- a/lib/BackgroundJobs/SubmitContentJob.php +++ b/lib/BackgroundJobs/SubmitContentJob.php @@ -16,7 +16,7 @@ use 
OCA\ContextChat\Db\QueueContentItem; use OCA\ContextChat\Db\QueueContentItemMapper; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Type\Source; use OCP\AppFramework\Utility\ITimeFactory; use OCP\BackgroundJob\IJobList; @@ -58,8 +58,8 @@ protected function run($argument): void { foreach ($bucketed as $userId => $entities) { $sources = array_map(function (QueueContentItem $item) use ($userId) { - $providerKey = ProviderConfigService::getConfigKey($item->getAppId(), $item->getProviderId()); - $sourceId = $providerKey . ': ' . $item->getItemId(); + $providerKey = ProviderService::getConfigKey($item->getAppId(), $item->getProviderId()); + $sourceId = ProviderService::getSourceId($item->getItemId(), $providerKey); return new Source( $userId, $sourceId, diff --git a/lib/Command/Prompt.php b/lib/Command/Prompt.php index 6924fc9..109bf82 100644 --- a/lib/Command/Prompt.php +++ b/lib/Command/Prompt.php @@ -13,7 +13,6 @@ namespace OCA\ContextChat\Command; use OCA\ContextChat\TextProcessing\ContextChatTaskType; -use OCA\ContextChat\TextProcessing\ScopedContextChatTaskType; use OCA\ContextChat\Type\ScopeType; use OCP\TextProcessing\FreePromptTaskType; use OCP\TextProcessing\IManager; @@ -80,26 +79,30 @@ protected function execute(InputInterface $input, OutputInterface $output) { throw new \InvalidArgumentException('Cannot use --context-sources with --context-provider'); } - if ($noContext) { - $task = new Task(FreePromptTaskType::class, $prompt, 'context_chat', $userId); - } elseif (!empty($contextSources)) { - $contextSources = preg_replace('/\s*,+\s*/', ',', $contextSources); - $contextSourcesArray = array_filter(explode(',', $contextSources), fn ($source) => !empty($source)); - $task = new Task(ScopedContextChatTaskType::class, json_encode([ - 'scopeType' => ScopeType::SOURCE, - 'scopeList' => $contextSourcesArray, - 'prompt' => $prompt, - ]), 'context_chat', 
$userId); - } elseif (!empty($contextProviders)) { - $contextProviders = preg_replace('/\s*,+\s*/', ',', $contextProviders); - $contextProvidersArray = array_filter(explode(',', $contextProviders), fn ($source) => !empty($source)); - $task = new Task(ScopedContextChatTaskType::class, json_encode([ - 'scopeType' => ScopeType::PROVIDER, - 'scopeList' => $contextProvidersArray, - 'prompt' => $prompt, - ]), 'context_chat', $userId); - } else { - $task = new Task(ContextChatTaskType::class, $prompt, 'context_chat', $userId); + try { + if ($noContext) { + $task = new Task(FreePromptTaskType::class, $prompt, 'context_chat', $userId); + } elseif (!empty($contextSources)) { + $contextSources = preg_replace('/\s*,+\s*/', ',', $contextSources); + $contextSourcesArray = array_filter(explode(',', $contextSources), fn ($source) => !empty($source)); + $task = new Task(ContextChatTaskType::class, json_encode([ + 'scopeType' => ScopeType::SOURCE, + 'scopeList' => $contextSourcesArray, + 'prompt' => $prompt, + ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } elseif (!empty($contextProviders)) { + $contextProviders = preg_replace('/\s*,+\s*/', ',', $contextProviders); + $contextProvidersArray = array_filter(explode(',', $contextProviders), fn ($source) => !empty($source)); + $task = new Task(ContextChatTaskType::class, json_encode([ + 'scopeType' => ScopeType::PROVIDER, + 'scopeList' => $contextProvidersArray, + 'prompt' => $prompt, + ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } else { + $task = new Task(ContextChatTaskType::class, json_encode([ 'prompt' => $prompt ], JSON_THROW_ON_ERROR), 'context_chat', $userId); + } + } catch (\JsonException $e) { + throw new \InvalidArgumentException('Invalid input, cannot encode JSON', intval($e->getCode()), $e); } $this->textProcessingManager->runTask($task); diff --git a/lib/Command/ScanFiles.php b/lib/Command/ScanFiles.php index 7c6b6ed..e972410 100644 --- a/lib/Command/ScanFiles.php +++ b/lib/Command/ScanFiles.php @@ -44,6 
+44,11 @@ protected function execute(InputInterface $input, OutputInterface $output) { ? explode(',', $input->getOption('mimetype')) : Application::MIMETYPES; + if ($mimeTypeFilter === false) { + $output->writeln('Invalid mime type filter'); + return 1; + } + $userId = $input->getArgument('user_id'); $scan = $this->scanService->scanUserFiles($userId, $mimeTypeFilter); foreach ($scan as $s) { diff --git a/lib/Controller/ProviderController.php b/lib/Controller/ProviderController.php new file mode 100644 index 0000000..2600a9a --- /dev/null +++ b/lib/Controller/ProviderController.php @@ -0,0 +1,58 @@ + + * + * @author Anupam Kumar + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +namespace OCA\ContextChat\Controller; + +use OCA\ContextChat\Service\ProviderService; +use OCP\AppFramework\Controller; +use OCP\AppFramework\Http\Attribute\NoAdminRequired; +use OCP\AppFramework\Http\DataResponse; +use OCP\IRequest; + +class ProviderController extends Controller { + + public function __construct( + string $appName, + IRequest $request, + private ProviderService $providerService, + ) { + parent::__construct($appName, $request); + } + + /** + * @return DataResponse + */ + #[NoAdminRequired] + public function getDefaultProviderKey(): DataResponse { + $providerKey = $this->providerService->getDefaultProviderKey(); + return new DataResponse($providerKey); + } + + /** + * @return DataResponse + */ + #[NoAdminRequired] + public function getProviders(): DataResponse { + $providers = $this->providerService->getEnrichedProviders(); + return new DataResponse($providers); + } +} diff --git a/lib/Listener/AppDisableListener.php b/lib/Listener/AppDisableListener.php index c94c9a8..21184d8 100644 --- a/lib/Listener/AppDisableListener.php +++ b/lib/Listener/AppDisableListener.php @@ -14,7 +14,7 @@ namespace OCA\ContextChat\Listener; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\App\Events\AppDisableEvent; use OCP\EventDispatcher\Event; use OCP\EventDispatcher\IEventListener; @@ -25,7 +25,7 @@ */ class AppDisableListener implements IEventListener { public function __construct( - private ProviderConfigService $configService, + private ProviderService $providerService, private LangRopeService $service, private LoggerInterface $logger, ) { @@ -36,7 +36,7 @@ public function handle(Event $event): void { return; } - foreach ($this->configService->getProviders() as $key => $values) { + foreach ($this->providerService->getProviders() as $key => $values) { /** @var string[] */ $identifierValues = explode('__', $key, 2); @@ -51,7 +51,7 @@ public function 
handle(Event $event): void { continue; } - $this->configService->removeProvider($appId, $providerId); + $this->providerService->removeProvider($appId, $providerId); $this->service->deleteSourcesByProviderForAllUsers($providerId); } } diff --git a/lib/Listener/FileListener.php b/lib/Listener/FileListener.php index 26c39e2..21979ff 100644 --- a/lib/Listener/FileListener.php +++ b/lib/Listener/FileListener.php @@ -10,6 +10,7 @@ use OCA\ContextChat\AppInfo\Application; use OCA\ContextChat\Db\QueueFile; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; use OCA\ContextChat\Service\QueueService; use OCA\ContextChat\Service\StorageService; use OCP\DB\Exception; @@ -120,7 +121,7 @@ public function handle(Event $event): void { if (!$node instanceof File) { continue; } - $fileRefs[] = 'file: ' . $node->getId(); + $fileRefs[] = ProviderService::getSourceId($node->getId()); } $this->langRopeService->deleteSources($userId, $fileRefs); @@ -129,7 +130,7 @@ public function handle(Event $event): void { return; } - $fileRef = 'file: ' . $node->getId(); + $fileRef = ProviderService::getSourceId($node->getId()); foreach ($userIds as $userId) { $this->langRopeService->deleteSources($userId, [$fileRef]); } @@ -194,7 +195,7 @@ public function postDelete(Node $node, bool $recurse = true): void { } foreach ($this->storageService->getUsersForFileId($node->getId()) as $userId) { - $fileRef = 'file: ' . 
$node->getId(); + $fileRef = ProviderService::getSourceId($node->getId()); $this->langRopeService->deleteSources($userId, [$fileRef]); } } diff --git a/lib/Public/ContentManager.php b/lib/Public/ContentManager.php index ca02c78..1debc0e 100644 --- a/lib/Public/ContentManager.php +++ b/lib/Public/ContentManager.php @@ -16,7 +16,7 @@ use OCA\ContextChat\Db\QueueContentItem; use OCA\ContextChat\Db\QueueContentItemMapper; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\BackgroundJob\IJobList; use OCP\Server; use Psr\Container\ContainerExceptionInterface; @@ -26,7 +26,7 @@ class ContentManager { public function __construct( private IJobList $jobList, - private ProviderConfigService $configService, + private ProviderService $providerService, private LangRopeService $service, private QueueContentItemMapper $mapper, private LoggerInterface $logger, @@ -47,11 +47,11 @@ public function registerContentProvider(string $providerClass): void { return; } - if ($this->configService->hasProvider($providerObj->getAppId(), $providerObj->getId())) { + if ($this->providerService->hasProvider($providerObj->getAppId(), $providerObj->getId())) { return; } - $this->configService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $providerClass); + $this->providerService->updateProvider($providerObj->getAppId(), $providerObj->getId(), $providerClass); if (!$this->jobList->has(InitialContentImportJob::class, $providerClass)) { $this->jobList->add(InitialContentImportJob::class, $providerClass); @@ -98,7 +98,9 @@ public function submitContent(string $appId, array $items): void { */ public function removeContentForUsers(string $appId, string $providerId, string $itemId, array $users): void { foreach ($users as $userId) { - $this->service->deleteSources($userId, [$this->configService->getConfigKey($appId, $providerId) . 
": $itemId"]); + $this->service->deleteSources($userId, [ + ProviderService::getSourceId($itemId, ProviderService::getConfigKey($appId, $providerId)) + ]); } } @@ -112,7 +114,7 @@ public function removeContentForUsers(string $appId, string $providerId, string */ public function removeAllContentForUsers(string $appId, string $providerId, array $users): void { foreach ($users as $userId) { - $this->service->deleteSourcesByProvider($userId, $this->configService->getConfigKey($appId, $providerId)); + $this->service->deleteSourcesByProvider($userId, ProviderService::getConfigKey($appId, $providerId)); } } } diff --git a/lib/Service/LangRopeService.php b/lib/Service/LangRopeService.php index 82517f1..2d6a073 100644 --- a/lib/Service/LangRopeService.php +++ b/lib/Service/LangRopeService.php @@ -35,11 +35,18 @@ public function __construct( private IAppManager $appManager, private IURLGenerator $urlGenerator, private IUserManager $userMan, - private ProviderConfigService $configService, private ?string $userId, ) { } + /** + * @param string $route + * @param string $method + * @param array $params + * @param string|null $contentType + * @return array + * @throws RuntimeException + */ private function requestToExApp( string $route, string $method = 'POST', @@ -121,6 +128,7 @@ private function requestToExApp( * @param string $userId * @param string $providerKey * @return void + * @throws RuntimeException */ public function deleteSourcesByProvider(string $userId, string $providerKey): void { $params = [ @@ -134,6 +142,7 @@ public function deleteSourcesByProvider(string $userId, string $providerKey): vo /** * @param string $providerKey * @return void + * @throws RuntimeException */ public function deleteSourcesByProviderForAllUsers(string $providerKey): void { $params = [ @@ -147,6 +156,7 @@ public function deleteSourcesByProviderForAllUsers(string $providerKey): void { * @param string $userId * @param string[] $sourceNames * @return void + * @throws RuntimeException */ public 
function deleteSources(string $userId, array $sourceNames): void { if (count($sourceNames) === 0) { @@ -164,6 +174,7 @@ public function deleteSources(string $userId, array $sourceNames): void { /** * @param Source[] $sources * @return void + * @throws RuntimeException */ public function indexSources(array $sources): void { if (count($sources) === 0) { @@ -173,14 +184,14 @@ public function indexSources(array $sources): void { $params = array_map(function (Source $source) { return [ 'name' => 'sources', - 'filename' => $source->reference, // eg. 'file: 555' + 'filename' => $source->reference, // eg. 'files__default: 555' 'contents' => $source->content, 'headers' => [ 'userId' => $source->userId, 'title' => $source->title, 'type' => $source->type, 'modified' => $source->modified, - 'provider' => $source->provider, // eg. 'file' + 'provider' => $source->provider, // eg. 'files__default' ], ]; }, $sources); @@ -188,33 +199,28 @@ public function indexSources(array $sources): void { $this->requestToExApp('/loadSources', 'PUT', $params, 'multipart/form-data'); } - public function query(string $userId, string $prompt, bool $useContext = true): array { - $params = [ - 'query' => $prompt, - 'userId' => $userId, - 'useContext' => $useContext, - ]; - - $response = $this->requestToExApp('/query', 'GET', $params); - return ['message' => $this->getWithPresentableSources($response['output'] ?? '', ...($response['sources'] ?? 
[]))]; - } - /** * @param string $userId * @param string $prompt - * @param string $scopeType - * @param array $scopeList + * @param bool $useContext + * @param ?string $scopeType + * @param ?array $scopeList * @return array + * @throws RuntimeException */ - public function scopedQuery(string $userId, string $prompt, string $scopeType, array $scopeList): array { + public function query(string $userId, string $prompt, bool $useContext = true, ?string $scopeType = null, ?array $scopeList = null): array { $params = [ 'query' => $prompt, 'userId' => $userId, - 'scopeType' => $scopeType, - 'scopeList' => $scopeList, + 'useContext' => $useContext, ]; + if ($scopeType !== null && $scopeList !== null) { + $params['useContext'] = true; + $params['scopeType'] = $scopeType; + $params['scopeList'] = $scopeList; + } - $response = $this->requestToExApp('/scopedQuery', 'POST', $params); + $response = $this->requestToExApp('/query', 'POST', $params); return ['message' => $this->getWithPresentableSources($response['output'] ?? '', ...($response['sources'] ?? []))]; } @@ -225,8 +231,9 @@ public function getWithPresentableSources(string $llmResponse, string ...$source $output = str_repeat(PHP_EOL, 3) . $this->l10n->t('Sources referenced to generate the above response:') . PHP_EOL; + $emptyFilesSourceId = ProviderService::getSourceId(''); foreach ($sourceRefs as $source) { - if (str_starts_with($source, 'file: ') && is_numeric($fileId = substr($source, 6))) { + if (str_starts_with($source, $emptyFilesSourceId) && is_numeric($fileId = substr($source, strlen($emptyFilesSourceId)))) { // use `overwritehost` setting in config.php to overwrite the host $output .= $this->urlGenerator->linkToRouteAbsolute('files.View.showFile', ['fileid' => $fileId]) . 
PHP_EOL; } elseif (str_contains($source, '__')) { diff --git a/lib/Service/ProviderConfigService.php b/lib/Service/ProviderConfigService.php index 6564c3f..dc229de 100644 --- a/lib/Service/ProviderConfigService.php +++ b/lib/Service/ProviderConfigService.php @@ -48,13 +48,14 @@ private function validateProvidersArray(array $providers): bool { */ public function getProviders(): array { $providers = []; + $providersString = $this->config->getAppValue(Application::APP_ID, 'providers', ''); - $providersString = $this->config->getAppValue(Application::APP_ID, 'providers'); if ($providersString !== '') { $providers = json_decode($providersString, true); if ($providers === null || !$this->validateProvidersArray($providers)) { $providers = []; + $this->config->setAppValue(Application::APP_ID, 'providers', ''); } } @@ -74,7 +75,7 @@ public function updateProvider( bool $isInitiated = false, ): void { $providers = $this->getProviders(); - $providers[$this->getConfigKey($appId, $providerId)] = [ + $providers[self::getConfigKey($appId, $providerId)] = [ 'isInitiated' => $isInitiated, 'classString' => $providerClass, ]; @@ -88,11 +89,11 @@ public function updateProvider( public function removeProvider(string $appId, ?string $providerId = null): void { $providers = $this->getProviders(); - if ($providerId !== null && isset($providers[$this->getConfigKey($appId, $providerId)])) { - unset($providers[$this->getConfigKey($appId, $providerId)]); + if ($providerId !== null && isset($providers[self::getConfigKey($appId, $providerId)])) { + unset($providers[self::getConfigKey($appId, $providerId)]); } elseif ($providerId === null) { foreach ($providers as $k => $v) { - if (str_starts_with($k, $appId)) { + if (str_starts_with($k, self::getConfigKey($appId, ''))) { unset($providers[$k]); } } @@ -108,6 +109,6 @@ public function removeProvider(string $appId, ?string $providerId = null): void */ public function hasProvider(string $appId, string $providerId): bool { $providers = 
$this->getProviders(); - return isset($providers[$this->getConfigKey($appId, $providerId)]); + return isset($providers[self::getConfigKey($appId, $providerId)]); } } diff --git a/lib/Service/ProviderService.php b/lib/Service/ProviderService.php new file mode 100644 index 0000000..7237af6 --- /dev/null +++ b/lib/Service/ProviderService.php @@ -0,0 +1,91 @@ + + */ + public function getEnrichedProviders(): array { + $providers = $this->providerService->getProviders(); + $sanitizedProviders = []; + + foreach ($providers as $providerKey => $metadata) { + // providerKey ($appId__$providerId) + /** @var string[] */ + $providerValues = explode('__', $providerKey, 2); + + if (count($providerValues) !== 2) { + $this->logger->info("Invalid provider key $providerKey, skipping"); + continue; + } + + [$appId, $providerId] = $providerValues; + + $user = $this->userId === null ? null : $this->userManager->get($this->userId); + if (!$this->appManager->isEnabledForUser($appId, $user)) { + $this->logger->info("App $appId is not enabled for user {$this->userId}, skipping"); + continue; + } + + $appInfo = $this->appManager->getAppInfo($appId); + if ($appInfo === null) { + $this->logger->info("Could not get app info for $appId, skipping"); + continue; + } + + try { + $icon = $this->urlGenerator->imagePath($appId, 'app-dark.svg'); + } catch (\RuntimeException $e) { + $this->logger->info("Could not get app image for $appId"); + $icon = ''; + } + + if (!isset($appInfo['name'])) { + $this->logger->info("App $appId does not have a name, skipping"); + continue; + } + + $sanitizedProviders[] = [ + 'id' => $providerKey, + 'label' => ucfirst($providerId) . ' - ' . 
$appInfo['name'], + 'icon' => $icon, + ]; + } + return $sanitizedProviders; + } +} diff --git a/lib/Service/ScanService.php b/lib/Service/ScanService.php index 76178a2..3b0a448 100644 --- a/lib/Service/ScanService.php +++ b/lib/Service/ScanService.php @@ -48,10 +48,6 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir $size = 0; foreach ($directory->getDirectoryListing() as $node) { if ($node instanceof File) { - if (!in_array($node->getMimeType(), $mimeTypeFilter)) { - continue; - } - $node_size = $node->getSize(); if ($size + $node_size > Application::CC_MAX_SIZE || count($sources) >= Application::CC_MAX_FILES) { @@ -60,22 +56,11 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir $size = 0; } - try { - $fileHandle = $node->fopen('r'); - } catch (\Exception $e) { - $this->logger->error('Could not open file ' . $node->getPath() . ' for reading: ' . $e->getMessage()); + $source = $this->getSourceFromFile($userId, $mimeTypeFilter, $node); + if ($source === null) { continue; } - $source = new Source( - $userId, - 'file: ' . $node->getId(), - $node->getPath(), - $fileHandle, - $node->getMTime(), - $node->getMimeType(), - 'file' - ); $sources[] = $source; $size += $node_size; @@ -97,6 +82,30 @@ public function scanDirectory(string $userId, array $mimeTypeFilter, Folder $dir return []; } + public function getSourceFromFile(string $userId, array $mimeTypeFilter, File $node): Source | null { + if (!in_array($node->getMimeType(), $mimeTypeFilter)) { + return null; + } + + try { + $fileHandle = $node->fopen('r'); + } catch (\Exception $e) { + $this->logger->error('Could not open file ' . $node->getPath() . ' for reading: ' . $e->getMessage()); + return null; + } + + $providerKey = ProviderService::getDefaultProviderKey(); + return new Source( + $userId, + $providerKey . ': ' . 
$node->getId(), + $node->getPath(), + $fileHandle, + $node->getMTime(), + $node->getMimeType(), + $providerKey, + ); + } + public function indexSources(array $sources): void { $this->langRopeService->indexSources($sources); } diff --git a/lib/TextProcessing/ContextChatProvider.php b/lib/TextProcessing/ContextChatProvider.php index 13521ab..515950c 100644 --- a/lib/TextProcessing/ContextChatProvider.php +++ b/lib/TextProcessing/ContextChatProvider.php @@ -3,10 +3,20 @@ declare(strict_types=1); namespace OCA\ContextChat\TextProcessing; +use OCA\ContextChat\AppInfo\Application; use OCA\ContextChat\Service\LangRopeService; +use OCA\ContextChat\Service\ProviderService; +use OCA\ContextChat\Service\ScanService; +use OCA\ContextChat\Type\ScopeType; +use OCP\Files\File; +use OCP\Files\Folder; +use OCP\Files\IRootFolder; +use OCP\Files\NotPermittedException; use OCP\IL10N; use OCP\TextProcessing\IProvider; use OCP\TextProcessing\IProviderWithUserId; +use Psr\Log\LoggerInterface; +use RuntimeException; /** * @template-implements IProviderWithUserId @@ -17,8 +27,11 @@ class ContextChatProvider implements IProvider, IProviderWithUserId { private ?string $userId = null; public function __construct( - private LangRopeService $langRopeService, private IL10N $l10n, + private IRootFolder $rootFolder, + private LoggerInterface $logger, + private LangRopeService $langRopeService, + private ScanService $scanService, ) { } @@ -26,18 +39,157 @@ public function getName(): string { return $this->l10n->t('Nextcloud Assistant Context Chat Provider'); } + /** + * Accepted scopeList formats: + * - "files__default: $fileId" + * - "$appId__$providerId" + * + * @param string $prompt JSON string with the following structure: + * { + * "scopeType": string, (optional key) + * "scopeList": list[string], (optional key) + * "prompt": string + * } + * + * @return string + */ public function process(string $prompt): string { if ($this->userId === null) { throw new \RuntimeException('User ID is required
to process the prompt.'); } - $response = $this->langRopeService->query($this->userId, $prompt); + try { + $parsedData = json_decode($prompt, true, flags: JSON_THROW_ON_ERROR | JSON_INVALID_UTF8_IGNORE); + } catch (\JsonException $e) { + throw new \RuntimeException( + 'Invalid JSON string, expected { "prompt": string } or { "scopeType": string, "scopeList": list[string], "prompt": string }', + intval($e->getCode()), $e, + ); + } + + if (!isset($parsedData['prompt']) || !is_string($parsedData['prompt'])) { + throw new \RuntimeException('Invalid JSON string, expected "prompt" key with string value'); + } + + if (!isset($parsedData['scopeType']) || !isset($parsedData['scopeList'])) { + $response = $this->langRopeService->query($this->userId, $parsedData['prompt']); + if (isset($response['error'])) { + throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); + } + return $response['message'] ?? ''; + } + + if (!is_string($parsedData['scopeType']) || !is_array($parsedData['scopeList'])) { + throw new \RuntimeException('Invalid JSON string, expected "scopeType" key with string value and "scopeList" key with array value'); + } + + try { + ScopeType::validate($parsedData['scopeType']); + } catch (\InvalidArgumentException $e) { + throw new \RuntimeException($e->getMessage(), intval($e->getCode()), $e); + } + + $scopeList = array_values(array_unique($parsedData['scopeList'])); + if (count($scopeList) === 0) { + throw new \RuntimeException('No sources found'); + } + + // index sources before the query, not needed for providers + if ($parsedData['scopeType'] === ScopeType::SOURCE) { + $processedScopes = $this->indexFiles(...$scopeList); + $this->logger->debug('All valid files indexed, querying ContextChat', ['scopeType' => $parsedData['scopeType'], 'scopeList' => $processedScopes]); + } else { + $processedScopes = $scopeList; + $this->logger->debug('No need to index sources, querying ContextChat', ['scopeType' => $parsedData['scopeType'], 'scopeList' =>
$processedScopes]); + } + + $response = $this->langRopeService->query( + $this->userId, + $parsedData['prompt'], + true, + $parsedData['scopeType'], + $processedScopes, + ); + if (isset($response['error'])) { - throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); + throw new \RuntimeException('No result in ContextChat response: ' . $response['error']); } + return $response['message'] ?? ''; } + /** + * @param string ...$scopeList Scope strings in the "sourceId: itemId" format, entries without the source id prefix are skipped + * @return array List of indexed files + */ + private function indexFiles(string ...$scopeList): array { + $nodes = []; + $indexedFiles = []; + + foreach ($scopeList as $scope) { + if (!str_starts_with($scope, ProviderService::getSourceId(''))) { + $this->logger->warning('Invalid source format, expected "sourceId: itemId"'); + continue; + } + + $nodeId = substr($scope, strlen(ProviderService::getSourceId(''))); + + try { + $userFolder = $this->rootFolder->getUserFolder($this->userId); + } catch (NotPermittedException $e) { + $this->logger->warning('Could not get user folder for user ' . $this->userId . ': ' . $e->getMessage()); + continue; + } + $node = $userFolder->getById(intval($nodeId)); + if (count($node) === 0) { + $this->logger->warning('Could not find file/folder with ID ' . $nodeId . ', skipping'); + continue; + } + $node = $node[0]; + + if (!$node instanceof File && !$node instanceof Folder) { + $this->logger->warning('Invalid source type, expected file/folder'); + continue; + } + + $nodes[] = [ + 'scope' => $scope, + 'node' => $node, + 'path' => $node->getPath(), + ]; + } + + // drop nodes nested inside a listed folder, the folder scan below is expected to cover them + $filteredNodes = $nodes; + foreach ($nodes as $node) { + if ($node['node'] instanceof Folder) { + $filteredNodes = array_filter($filteredNodes, function ($n) use ($node) { + return !str_starts_with($n['path'], $node['path'] .
DIRECTORY_SEPARATOR); + }); + // the folder itself stays in the list: its own path never matches the "path/" prefix + } + } + + foreach ($filteredNodes as $node) { + try { + if ($node['node'] instanceof File) { + $source = $this->scanService->getSourceFromFile($this->userId, Application::MIMETYPES, $node['node']); + $this->scanService->indexSources([$source]); + $indexedFiles[] = $node['scope']; + } elseif ($node['node'] instanceof Folder) { + $indexedFiles = array_merge( + $indexedFiles, + iterator_to_array($this->scanService->scanDirectory($this->userId, Application::MIMETYPES, $node['node']), false), + ); + } + } catch (RuntimeException $e) { + $this->logger->warning('Could not index file/folder with ID ' . $node['node']->getId() . ': ' . $e->getMessage()); + } + } + + return $indexedFiles; + } + public function getTaskType(): string { return ContextChatTaskType::class; } diff --git a/lib/TextProcessing/ScopedContextChatProvider.php b/lib/TextProcessing/ScopedContextChatProvider.php deleted file mode 100644 index 3768e92..0000000 --- a/lib/TextProcessing/ScopedContextChatProvider.php +++ /dev/null @@ -1,93 +0,0 @@ - - * @template-implements IProvider - */ -class ScopedContextChatProvider implements IProvider, IProviderWithUserId { - - private ?string $userId = null; - - public function __construct( - private LangRopeService $langRopeService, - private IL10N $l10n, - ) { - } - - public function getName(): string { - return $this->l10n->t('Nextcloud Assistant Scoped Context Chat Provider'); - } - - /** - * @param string $prompt JSON string with the following structure: - * { - * "scopeType": string, - * "scopeList": list[string], - * "prompt": string, - * } - * - * @return string - */ - public function process(string $prompt): string { - if ($this->userId === null) { - throw new \RuntimeException('User ID is required to process the prompt.'); - } - - try { - $parsedData = json_decode($prompt, true, flags: JSON_THROW_ON_ERROR | JSON_INVALID_UTF8_IGNORE); - } catch (\JsonException $e) { - throw new \RuntimeException( - 'Invalid
JSON string, expected { "scopeType": string, "scopeList": list[string], "prompt": string }', - intval($e->getCode()), $e, - ); - } - - if ( - !is_array($parsedData) - || !isset($parsedData['scopeType']) - || !is_string($parsedData['scopeType']) - || !isset($parsedData['scopeList']) - || !is_array($parsedData['scopeList']) - || !isset($parsedData['prompt']) - || !is_string($parsedData['prompt']) - ) { - throw new \RuntimeException('Invalid JSON string, expected { "scopeType": string, "scopeList": list[string], "prompt": string }'); - } - - try { - ScopeType::validate($parsedData['scopeType']); - } catch (\InvalidArgumentException $e) { - throw new \RuntimeException($e->getMessage(), intval($e->getCode()), $e); - } - - $response = $this->langRopeService->scopedQuery( - $this->userId, - $parsedData['prompt'], - $parsedData['scopeType'], - $parsedData['scopeList'], - ); - - if (isset($response['error'])) { - throw new \RuntimeException('No result in ContextChat response. ' . $response['error']); - } - - return $response['message'] ?? ''; - } - - public function getTaskType(): string { - return ScopedContextChatTaskType::class; - } - - public function setUserId(?string $userId): void { - $this->userId = $userId; - } -} diff --git a/lib/TextProcessing/ScopedContextChatTaskType.php b/lib/TextProcessing/ScopedContextChatTaskType.php deleted file mode 100644 index fc204de..0000000 --- a/lib/TextProcessing/ScopedContextChatTaskType.php +++ /dev/null @@ -1,52 +0,0 @@ - - * - * @author Julien Veyssier - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. 
- * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -namespace OCA\ContextChat\TextProcessing; - -use OCP\IL10N; -use OCP\TextProcessing\ITaskType; - -class ScopedContextChatTaskType implements ITaskType { - public function __construct( - private IL10N $l, - ) { - } - - /** - * @inheritDoc - * @since 27.1.0 - */ - public function getName(): string { - return $this->l->t('Scoped Context Chat'); - } - - /** - * @inheritDoc - * @since 27.1.0 - */ - public function getDescription(): string { - return $this->l->t('Ask a question about the data selected by you.'); - } -} diff --git a/tests/integration/ContentManagerTest.php b/tests/integration/ContentManagerTest.php index cc1e70c..9a30668 100644 --- a/tests/integration/ContentManagerTest.php +++ b/tests/integration/ContentManagerTest.php @@ -22,7 +22,7 @@ use OCA\ContextChat\Public\ContentManager; use OCA\ContextChat\Public\IContentProvider; use OCA\ContextChat\Service\LangRopeService; -use OCA\ContextChat\Service\ProviderConfigService; +use OCA\ContextChat\Service\ProviderService; use OCP\BackgroundJob\IJobList; use OCP\Server; use PHPUnit\Framework\MockObject\MockObject; @@ -32,8 +32,8 @@ class ContentManagerTest extends TestCase { /** @var MockObject | QueueContentItemMapper */ private QueueContentItemMapper $mapper; - /** @var MockObject | ProviderConfigService */ - private ProviderConfigService $configService; + /** @var MockObject | ProviderService */ + private ProviderService $providerService; /** @var MockObject | LangRopeService */ private LangRopeService $service; @@ -41,30 +41,34 @@ class ContentManagerTest extends TestCase { private LoggerInterface $logger; private IJobList 
$jobList; - private bool $initCalled = false; + // private bool $initCalled = false; private static string $providerClass = 'OCA\ContextChat\Tests\ContentProvider'; public function setUp(): void { $this->jobList = Server::get(IJobList::class); $this->logger = Server::get(LoggerInterface::class); $this->mapper = $this->createMock(QueueContentItemMapper::class); - $this->configService = $this->createMock(ProviderConfigService::class); + $this->providerService = $this->createMock(ProviderService::class); $this->service = $this->createMock(LangRopeService::class); - $this->configService + $this->providerService ->method('getProviders') ->willReturn([ - ProviderConfigService::getConfigKey(Application::APP_ID, 'test-provider') => [ + ProviderService::getDefaultProviderKey() => [ + 'isInitiated' => true, + 'classString' => '', + ], + ProviderService::getConfigKey(Application::APP_ID, 'test-provider') => [ 'isInitiated' => false, 'classString' => static::$providerClass, ], ]); - $this->overwriteService(ProviderConfigService::class, $this->configService); + // $this->overwriteService(ProviderConfigService::class, $this->providerConfigService); // using this app's app id to pass the check that the app is enabled for the user $providerObj = new ContentProvider(Application::APP_ID, 'test-provider', function () { - $this->initCalled = true; + // $this->initCalled = true; }); $providerClass = get_class($providerObj); @@ -74,7 +78,7 @@ public function setUp(): void { $this->contentManager = new ContentManager( $this->jobList, - $this->configService, + $this->providerService, $this->service, $this->mapper, $this->logger, @@ -102,14 +106,14 @@ public function testRegisterContentProvider( string $providerId, bool $registrationSuccessful, ): void { - $this->configService + $this->providerService ->expects($registrationSuccessful ? 
$this->once() : $this->never()) ->method('hasProvider') ->with($appId, $providerId) ->willReturn(false); - $this->configService - ->expects($registrationSuccessful ? $this->exactly(2) : $this->never()) + $this->providerService + ->expects($registrationSuccessful ? $this->once() : $this->never()) ->method('updateProvider') ->with($appId, $providerId, $providerClass); @@ -118,15 +122,8 @@ public function testRegisterContentProvider( $jobsIter = $this->jobList->getJobsIterator(InitialContentImportJob::class, 1, 0); if ($registrationSuccessful) { $this->assertNotNull($jobsIter); + $this->jobList->remove(InitialContentImportJob::class, $providerClass); } - - foreach ($jobsIter as $job) { - if ($job->getArgument() === $providerClass) { - $job->execute($this->jobList); - } - } - - $this->assertTrue(($this->initCalled && $registrationSuccessful) || (!$this->initCalled && !$registrationSuccessful)); } public function testSubmitContent(): void { diff --git a/tests/integration/ProviderConfigServiceTest.php b/tests/integration/ProviderConfigServiceTest.php index 200a566..7954a11 100644 --- a/tests/integration/ProviderConfigServiceTest.php +++ b/tests/integration/ProviderConfigServiceTest.php @@ -40,11 +40,11 @@ public function testGetConfigKey(): void { public static function dataBank(): array { $validData = [ - 'app1__provider1' => [ + ProviderConfigService::getConfigKey('app1', 'provider1') => [ 'isInitiated' => true, 'classString' => 'class1', ], - 'app1__provider2' => [ + ProviderConfigService::getConfigKey('app1', 'provider2') => [ 'isInitiated' => false, 'classString' => 'class2', ], @@ -105,9 +105,9 @@ public function testUpdateProvider(string $returnVal, array $providers): void { ->willReturn($returnVal); $this->config - ->expects($this->once()) + ->expects($returnVal === 'invalid' ? 
$this->exactly(2) : $this->once()) ->method('setAppValue') - ->with(Application::APP_ID, 'providers', $setProvidersValue); + ->with(Application::APP_ID, 'providers', $this->logicalOr($this->equalTo(''), $this->equalTo($setProvidersValue))); $this->configService->updateProvider($appId, $providerId, $providerClass, $isInitiated); } @@ -134,9 +134,12 @@ public function testRemoveProvider(string $returnVal, array $providers): void { } $this->config - ->expects($this->once()) + ->expects($returnVal === 'invalid' ? $this->exactly(2) : $this->once()) ->method('setAppValue') - ->with(Application::APP_ID, 'providers', json_encode($providers)); + ->with(Application::APP_ID, 'providers', $this->logicalOr( + $this->equalTo(''), + $this->equalTo(json_encode($providers)) + )); $this->configService->removeProvider($appId, $providerId); } From b801406c271be896606af3c525dc87a9545034cb Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 21 Mar 2024 14:23:54 +0530 Subject: [PATCH 2/2] integration test updates & fixes Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 59e63c1..692ec51 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -20,7 +20,7 @@ concurrency: jobs: transcription: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: # do not stop on another job's failure @@ -155,15 +155,18 @@ jobs: - name: Install and init backend run: | cd context_chat_backend - pip install --no-deps -r reqs.txt + pip install --no-deps -r requirements.txt + pip install --upgrade pip setuptools wheel + CMAKE_ARGS="-DLLAMA_OPENBLAS=on" pip install llama-cpp-python cp example.env .env + cp config.cpu.yaml config.yaml echo "NEXTCLOUD_URL=http://localhost:8080" >> .env ./main.py &> backend_logs & - name: Register backend run: | ./occ app_api:daemon:register --net host 
manual_install "Manual Install" manual-install http localhost http://localhost:8080 - ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"1.1.1\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish + ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"1.2.0\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish - name: Scan files run: |