From c94278be85caee5d264226c514b4e41fffcaa7c6 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 8 May 2024 20:24:40 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=99=20feat:=20Ollama=20Vision=20Suppor?= =?UTF-8?q?t=20(#2643)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: checkVisionRequest, search availableModels for valid vision model instead of using default * feat: install ollama-js, add typedefs * feat: Ollama Vision Support * ci: fix test --- api/app/clients/OllamaClient.js | 154 +++++++++++++++++++++ api/app/clients/OpenAIClient.js | 67 +++++++-- api/app/clients/specs/OpenAIClient.test.js | 42 +++++- api/package.json | 1 + api/server/services/ModelService.js | 54 +------- api/server/services/ModelService.spec.js | 46 +----- api/typedefs.js | 21 ++- api/utils/deriveBaseURL.js | 28 ++++ api/utils/deriveBaseURL.spec.js | 74 ++++++++++ api/utils/index.js | 2 + package-lock.json | 16 ++- packages/data-provider/src/config.ts | 2 +- 12 files changed, 390 insertions(+), 117 deletions(-) create mode 100644 api/app/clients/OllamaClient.js create mode 100644 api/utils/deriveBaseURL.js create mode 100644 api/utils/deriveBaseURL.spec.js diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js new file mode 100644 index 0000000000..57bc8754fb --- /dev/null +++ b/api/app/clients/OllamaClient.js @@ -0,0 +1,154 @@ +const { z } = require('zod'); +const axios = require('axios'); +const { Ollama } = require('ollama'); +const { deriveBaseURL } = require('~/utils'); +const { logger } = require('~/config'); + +const ollamaPayloadSchema = z.object({ + mirostat: z.number().optional(), + mirostat_eta: z.number().optional(), + mirostat_tau: z.number().optional(), + num_ctx: z.number().optional(), + repeat_last_n: z.number().optional(), + repeat_penalty: z.number().optional(), + temperature: z.number().optional(), + seed: z.number().nullable().optional(), + stop: z.array(z.string()).optional(), + tfs_z: z.number().optional(), + num_predict: z.number().optional(), + top_k: z.number().optional(), + top_p: z.number().optional(), + stream: z.optional(z.boolean()), + model: z.string(), +}); + +/** + * @param {string} imageUrl + * @returns {string} + * @throws {Error} + */ +const getValidBase64 = (imageUrl) => { + const parts = imageUrl.split(';base64,'); + + if (parts.length === 2) { + return parts[1]; + } else { + logger.error('Invalid or no Base64 string found in URL.'); + } +}; + +class OllamaClient { + constructor(options = {}) { + const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434'); + /** @type {Ollama} */ + this.client = new Ollama({ host }); + } + + /** + * Fetches Ollama models from the specified base API path. + * @param {string} baseURL + * @returns {Promise} The Ollama models. + */ + static async fetchModels(baseURL) { + let models = []; + if (!baseURL) { + return models; + } + try { + const ollamaEndpoint = deriveBaseURL(baseURL); + /** @type {Promise>} */ + const response = await axios.get(`${ollamaEndpoint}/api/tags`); + models = response.data.models.map((tag) => tag.name); + return models; + } catch (error) { + const logMessage = + 'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; + logger.error(logMessage, error); + return []; + } + } + + /** + * @param {ChatCompletionMessage[]} messages + * @returns {OllamaMessage[]} + */ + static formatOpenAIMessages(messages) { + const ollamaMessages = []; + + for (const message of messages) { + if (typeof message.content === 'string') { + ollamaMessages.push({ + role: message.role, + content: message.content, + }); + continue; + } + + let aggregatedText = ''; + let imageUrls = []; + + for (const content of message.content) { + if (content.type === 'text') { + aggregatedText += content.text + ' '; + } else if (content.type === 'image_url') { + imageUrls.push(getValidBase64(content.image_url.url)); + } + } + + const ollamaMessage = { + role: message.role, + content: aggregatedText.trim(), + }; + + if (imageUrls.length > 0) { + ollamaMessage.images = imageUrls; + } + + ollamaMessages.push(ollamaMessage); + } + + return ollamaMessages; + } + + /*** + * @param {Object} params + * @param {ChatCompletionPayload} params.payload + * @param {onTokenProgress} params.onProgress + * @param {AbortController} params.abortController + */ + async chatCompletion({ payload, onProgress, abortController = null }) { + let intermediateReply = ''; + + const parameters = ollamaPayloadSchema.parse(payload); + const messages = OllamaClient.formatOpenAIMessages(payload.messages); + + if (parameters.stream) { + const stream = await this.client.chat({ + messages, + ...parameters, + }); + + for await (const chunk of stream) { + const token = chunk.message.content; + intermediateReply += token; + onProgress(token); + if (abortController.signal.aborted) { + stream.controller.abort(); + break; + } + } + } + // TODO: regular completion + else { + // const generation = await this.client.generate(payload); + } + + return intermediateReply; + } + catch(err) { + logger.error('[OllamaClient.chatCompletion]', err); + throw err; + } +} + +module.exports = { OllamaClient, ollamaPayloadSchema }; diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 9abf36f588..97815d819e 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,4 +1,5 @@ const OpenAI = require('openai'); +const { OllamaClient } = require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { Constants, @@ -234,23 +235,52 @@ class OpenAIClient extends BaseClient { * @param {MongoFile[]} attachments */ checkVisionRequest(attachments) { - const availableModels = this.options.modelsConfig?.[this.options.endpoint]; - this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); - - const visionModelAvailable = availableModels?.includes(this.defaultVisionModel); - if ( - attachments && - attachments.some((file) => file?.type && file?.type?.includes('image')) && - visionModelAvailable && - !this.isVisionModel - ) { - this.modelOptions.model = this.defaultVisionModel; - this.isVisionModel = true; + if (!attachments) { + return; } + const availableModels = this.options.modelsConfig?.[this.options.endpoint]; + if (!availableModels) { + return; + } + + let visionRequestDetected = false; + for (const file of attachments) { + if (file?.type?.includes('image')) { + visionRequestDetected = true; + break; + } + } + if (!visionRequestDetected) { + return; + } + + this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); if (this.isVisionModel) { delete this.modelOptions.stop; + return; } + + for (const model of availableModels) { + if (!validateVisionModel({ model, availableModels })) { + continue; + } + this.modelOptions.model = model; + this.isVisionModel = true; + delete this.modelOptions.stop; + return; + } + + if (!availableModels.includes(this.defaultVisionModel)) { + return; + } + if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) { + return; + } + + this.modelOptions.model = this.defaultVisionModel; + this.isVisionModel = true; + delete this.modelOptions.stop; } setupTokens() { @@ -715,6 +745,10 @@ class OpenAIClient extends BaseClient { * In case of failure, it will return the default title, "New Chat". */ async titleConvo({ text, conversationId, responseText = '' }) { + if (this.options.attachments) { + delete this.options.attachments; + } + let title = 'New Chat'; const convo = `||>User: "${truncateText(text)}" @@ -1124,6 +1158,15 @@ ${convo} }); } + if (this.options.attachments && this.options.endpoint?.toLowerCase() === 'ollama') { + const ollamaClient = new OllamaClient({ baseURL }); + return await ollamaClient.chatCompletion({ + payload: modelOptions, + onProgress, + abortController, + }); + } + let UnexpectedRoleError = false; if (modelOptions.stream) { const stream = await openai.beta.chat.completions diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 8c2226215c..7ef4fdcae5 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -157,12 +157,19 @@ describe('OpenAIClient', () => { azureOpenAIApiVersion: '2020-07-01-preview', }; + let originalWarn; + beforeAll(() => { - jest.spyOn(console, 'warn').mockImplementation(() => {}); + originalWarn = console.warn; + console.warn = jest.fn(); }); afterAll(() => { - console.warn.mockRestore(); + console.warn = originalWarn; + }); + + beforeEach(() => { + console.warn.mockClear(); }); beforeEach(() => { @@ -662,4 +669,35 @@ describe('OpenAIClient', () => { expect(constructorArgs.baseURL).toBe(expectedURL); }); }); + + describe('checkVisionRequest functionality', () => { + let client; + const attachments = [{ type: 'image/png' }]; + + beforeEach(() => { + client = new OpenAIClient('test-api-key', { + endpoint: 'ollama', + modelOptions: { + model: 'initial-model', + }, + modelsConfig: { + ollama: ['initial-model', 'llava', 'other-model'], + }, + }); + + client.defaultVisionModel = 'non-valid-default-model'; + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('should set "llava" as the model if it is the first valid model when default validation fails', () => { + client.checkVisionRequest(attachments); + + expect(client.modelOptions.model).toBe('llava'); + expect(client.isVisionModel).toBeTruthy(); + expect(client.modelOptions.stop).toBeUndefined(); + }); + }); }); diff --git a/api/package.json b/api/package.json index 00761a1b1c..2bdbb303a7 100644 --- a/api/package.json +++ b/api/package.json @@ -75,6 +75,7 @@ "multer": "^1.4.5-lts.1", "nodejs-gpt": "^1.37.4", "nodemailer": "^6.9.4", + "ollama": "^0.5.0", "openai": "4.36.0", "openai-chat-tokens": "^0.2.8", "openid-client": "^5.4.2", diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 540e7240a4..3c560b297b 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -2,60 +2,11 @@ const axios = require('axios'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require('~/utils'); +const { OllamaClient } = require('~/app/clients/OllamaClient'); const getLogStores = require('~/cache/getLogStores'); -const { logger } = require('~/config'); const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config; -/** - * Extracts the base URL from the provided URL. - * @param {string} fullURL - The full URL. - * @returns {string} The base URL. - */ -function deriveBaseURL(fullURL) { - try { - const parsedUrl = new URL(fullURL); - const protocol = parsedUrl.protocol; - const hostname = parsedUrl.hostname; - const port = parsedUrl.port; - - // Check if the parsed URL components are meaningful - if (!protocol || !hostname) { - return fullURL; - } - - // Reconstruct the base URL - return `${protocol}//${hostname}${port ? `:${port}` : ''}`; - } catch (error) { - logger.error('Failed to derive base URL', error); - return fullURL; // Return the original URL in case of any exception - } -} - -/** - * Fetches Ollama models from the specified base API path. - * @param {string} baseURL - * @returns {Promise} The Ollama models. - */ -const fetchOllamaModels = async (baseURL) => { - let models = []; - if (!baseURL) { - return models; - } - try { - const ollamaEndpoint = deriveBaseURL(baseURL); - /** @type {Promise>} */ - const response = await axios.get(`${ollamaEndpoint}/api/tags`); - models = response.data.models.map((tag) => tag.name); - return models; - } catch (error) { - const logMessage = - 'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; - logger.error(logMessage, error); - return []; - } -}; - /** * Fetches OpenAI models from the specified base API path or Azure, based on the provided configuration. * @@ -92,7 +43,7 @@ const fetchModels = async ({ } if (name && name.toLowerCase().startsWith('ollama')) { - return await fetchOllamaModels(baseURL); + return await OllamaClient.fetchModels(baseURL); } try { @@ -281,7 +232,6 @@ const getGoogleModels = () => { module.exports = { fetchModels, - deriveBaseURL, getOpenAIModels, getChatGPTBrowserModels, getAnthropicModels, diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index 1abb152502..fc7c8b1079 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,7 +1,7 @@ const axios = require('axios'); const { logger } = require('~/config'); -const { fetchModels, getOpenAIModels, deriveBaseURL } = require('./ModelService'); +const { fetchModels, getOpenAIModels } = require('./ModelService'); jest.mock('~/utils', () => { const originalUtils = jest.requireActual('~/utils'); return { @@ -329,47 +329,3 @@ describe('fetchModels with Ollama specific logic', () => { ); }); }); - -describe('deriveBaseURL', () => { - it('should extract the base URL correctly from a full URL with a port', () => { - const fullURL = 'https://example.com:8080/path?query=123'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('https://example.com:8080'); - }); - - it('should extract the base URL correctly from a full URL without a port', () => { - const fullURL = 'https://example.com/path?query=123'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('https://example.com'); - }); - - it('should handle URLs using the HTTP protocol', () => { - const fullURL = 'http://example.com:3000/path?query=123'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('http://example.com:3000'); - }); - - it('should return only the protocol and hostname if no port is specified', () => { - const fullURL = 'http://example.com/path?query=123'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('http://example.com'); - }); - - it('should handle URLs with uncommon protocols', () => { - const fullURL = 'ftp://example.com:2121/path?query=123'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('ftp://example.com:2121'); - }); - - it('should handle edge case where URL ends with a slash', () => { - const fullURL = 'https://example.com/'; - const baseURL = deriveBaseURL(fullURL); - expect(baseURL).toEqual('https://example.com'); - }); - - it('should return the original URL if the URL is invalid', () => { - const invalidURL = 'htp:/example.com:8080'; - const result = deriveBaseURL(invalidURL); - expect(result).toBe(invalidURL); - }); -}); diff --git a/api/typedefs.js b/api/typedefs.js index df5e8be2be..f7970be4f3 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -7,6 +7,13 @@ * @typedef {import('openai').OpenAI} OpenAI * @memberof typedefs */ + +/** + * @exports Ollama + * @typedef {import('ollama').Ollama} Ollama + * @memberof typedefs + */ + /** * @exports AxiosResponse * @typedef {import('axios').AxiosResponse} AxiosResponse @@ -62,8 +69,14 @@ */ /** - * @exports ChatCompletionMessages - * @typedef {import('openai').OpenAI.ChatCompletionMessageParam} ChatCompletionMessages + * @exports OllamaMessage + * @typedef {import('ollama').Message} OllamaMessage + * @memberof typedefs + */ + +/** + * @exports ChatCompletionMessage + * @typedef {import('openai').OpenAI.ChatCompletionMessageParam} ChatCompletionMessage * @memberof typedefs */ @@ -1153,7 +1166,7 @@ /** * Main entrypoint for API completion calls * @callback sendCompletion - * @param {Array | string} payload - The messages or prompt to send to the model + * @param {Array | string} payload - The messages or prompt to send to the model * @param {object} opts - Options for the completion * @param {onTokenProgress} opts.onProgress - Callback function to handle token progress * @param {AbortController} opts.abortController - AbortController instance @@ -1164,7 +1177,7 @@ /** * Legacy completion handler for OpenAI API. * @callback getCompletion - * @param {Array | string} input - Array of messages or a single prompt string + * @param {Array | string} input - Array of messages or a single prompt string * @param {(event: object | string) => Promise} onProgress - SSE progress handler * @param {onTokenProgress} onTokenProgress - Token progress handler * @param {AbortController} [abortController] - AbortController instance diff --git a/api/utils/deriveBaseURL.js b/api/utils/deriveBaseURL.js new file mode 100644 index 0000000000..c377ddf874 --- /dev/null +++ b/api/utils/deriveBaseURL.js @@ -0,0 +1,28 @@ +const { logger } = require('~/config'); + +/** + * Extracts the base URL from the provided URL. + * @param {string} fullURL - The full URL. + * @returns {string} The base URL. + */ +function deriveBaseURL(fullURL) { + try { + const parsedUrl = new URL(fullURL); + const protocol = parsedUrl.protocol; + const hostname = parsedUrl.hostname; + const port = parsedUrl.port; + + // Check if the parsed URL components are meaningful + if (!protocol || !hostname) { + return fullURL; + } + + // Reconstruct the base URL + return `${protocol}//${hostname}${port ? `:${port}` : ''}`; + } catch (error) { + logger.error('Failed to derive base URL', error); + return fullURL; // Return the original URL in case of any exception + } +} + +module.exports = deriveBaseURL; diff --git a/api/utils/deriveBaseURL.spec.js b/api/utils/deriveBaseURL.spec.js new file mode 100644 index 0000000000..6df0bc65cd --- /dev/null +++ b/api/utils/deriveBaseURL.spec.js @@ -0,0 +1,74 @@ +const axios = require('axios'); +const deriveBaseURL = require('./deriveBaseURL'); +jest.mock('~/utils', () => { + const originalUtils = jest.requireActual('~/utils'); + return { + ...originalUtils, + processModelData: jest.fn((...args) => { + return originalUtils.processModelData(...args); + }), + }; +}); + +jest.mock('axios'); +jest.mock('~/cache/getLogStores', () => + jest.fn().mockImplementation(() => ({ + get: jest.fn().mockResolvedValue(undefined), + set: jest.fn().mockResolvedValue(true), + })), +); +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, +})); + +axios.get.mockResolvedValue({ + data: { + data: [{ id: 'model-1' }, { id: 'model-2' }], + }, +}); + +describe('deriveBaseURL', () => { + it('should extract the base URL correctly from a full URL with a port', () => { + const fullURL = 'https://example.com:8080/path?query=123'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('https://example.com:8080'); + }); + + it('should extract the base URL correctly from a full URL without a port', () => { + const fullURL = 'https://example.com/path?query=123'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('https://example.com'); + }); + + it('should handle URLs using the HTTP protocol', () => { + const fullURL = 'http://example.com:3000/path?query=123'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('http://example.com:3000'); + }); + + it('should return only the protocol and hostname if no port is specified', () => { + const fullURL = 'http://example.com/path?query=123'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('http://example.com'); + }); + + it('should handle URLs with uncommon protocols', () => { + const fullURL = 'ftp://example.com:2121/path?query=123'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('ftp://example.com:2121'); + }); + + it('should handle edge case where URL ends with a slash', () => { + const fullURL = 'https://example.com/'; + const baseURL = deriveBaseURL(fullURL); + expect(baseURL).toEqual('https://example.com'); + }); + + it('should return the original URL if the URL is invalid', () => { + const invalidURL = 'htp:/example.com:8080'; + const result = deriveBaseURL(invalidURL); + expect(result).toBe(invalidURL); + }); +}); diff --git a/api/utils/index.js b/api/utils/index.js index 7b539cbb14..29357f7adb 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,6 +1,7 @@ const loadYaml = require('./loadYaml'); const tokenHelpers = require('./tokens'); const azureUtils = require('./azureUtils'); +const deriveBaseURL = require('./deriveBaseURL'); const logAxiosError = require('./logAxiosError'); const extractBaseURL = require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); @@ -9,6 +10,7 @@ module.exports = { loadYaml, ...tokenHelpers, ...azureUtils, + deriveBaseURL, logAxiosError, extractBaseURL, findMessageContent, diff --git a/package-lock.json b/package-lock.json index e1031dd324..10a5815ac2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -83,6 +83,7 @@ "multer": "^1.4.5-lts.1", "nodejs-gpt": "^1.37.4", "nodemailer": "^6.9.4", + "ollama": "^0.5.0", "openai": "4.36.0", "openai-chat-tokens": "^0.2.8", "openid-client": "^5.4.2", @@ -21555,6 +21556,14 @@ "node": "^10.13.0 || >=12.0.0" } }, + "node_modules/ollama": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/ollama/-/ollama-0.5.0.tgz", + "integrity": "sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==", + "dependencies": { + "whatwg-fetch": "^3.6.20" + } + }, "node_modules/on-exit-leak-free": { "version": "2.1.2", "resolved": "https://registry.npmjs.org/on-exit-leak-free/-/on-exit-leak-free-2.1.2.tgz", @@ -28084,6 +28093,11 @@ "node": ">=0.10.0" } }, + "node_modules/whatwg-fetch": { + "version": "3.6.20", + "resolved": "https://registry.npmjs.org/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz", + "integrity": "sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==" + }, "node_modules/whatwg-mimetype": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-3.0.0.tgz", @@ -29153,7 +29167,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.5.9", + "version": "0.6.0", "license": "ISC", "dependencies": { "@types/js-yaml": "^4.0.9", diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 1c2fb9bbfe..98af3a7454 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -459,13 +459,13 @@ export const supportsBalanceCheck = { }; export const visionModels = [ + 'gpt-4-turbo', 'gpt-4-vision', 'llava', 'llava-13b', 'gemini-pro-vision', 'claude-3', 'gemini-1.5', - 'gpt-4-turbo', ]; export enum VisionModes { generative = 'generative',