Skip to content

Commit bad6a7d

Browse files
committed
Update mistral model cache mechanism
1 parent 10115d8 commit bad6a7d

File tree

3 files changed

+54
-131
lines changed

3 files changed

+54
-131
lines changed

lua/codecompanion/adapters/http/mistral/get_models.lua

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,30 @@ local M = {}
1616
---@field formatted_name string?
1717
---@field opts {can_use_tools: boolean, has_vision: boolean}
1818

19-
---@alias MistralModelCache table<string, MistralModelInfo>
20-
21-
---@type MistralModelCache
19+
---@class MistralApiCapabilities
20+
---@field completion_chat boolean
21+
---@field vision boolean
22+
---@field function_calling boolean
23+
24+
---@class MistralApiModel
25+
---@field id string
26+
---@field name string
27+
---@field aliases string[]
28+
---@field deprecation any?
29+
---@field capabilities MistralApiCapabilities
30+
31+
---@type table<string, MistralModelInfo>
2232
local _cached_models = {}
2333

24-
---@return MistralModelCache
34+
---@return table<string, MistralModelInfo>
2535
local function get_cached_models()
26-
assert(_cached_models ~= nil, "Model info is not available in the cache.")
27-
local models = _cached_models
28-
return models
36+
return _cached_models
2937
end
3038

3139
---When given a list of names that are aliases for the same model, returns the preferred name.
3240
---The preference order is: names ending with '-latest' (highest priority),
3341
---then names ending with four digits (e.g., '-2023'), and finally other names.
34-
---@param names table List of names, should at least contain 1 entry
42+
---@param names string[] List of names, should at least contain 1 entry
3543
---@return string?
3644
local function preferred_model_name(names)
3745
local high_score = -1
@@ -55,16 +63,13 @@ end
5563

5664
---Multiple id can refer to the same Model,
5765
---This function removes duplicates, only using preferred model name
58-
---@param models table[] Table as returned by Mistral API response
59-
---@return table[]
66+
---@param models MistralApiModel[] Table as returned by Mistral API response
67+
---@return MistralApiModel[]
6068
local function dedup_models(models)
6169
local preferred_names = {}
6270
for _, model in ipairs(models) do
6371
if model.id then
64-
local aliases = {}
65-
if model.aliases then
66-
aliases = model.aliases
67-
end
72+
local aliases = model.aliases
6873
table.insert(aliases, model.id)
6974

7075
if preferred_names[model.id] then
@@ -92,25 +97,27 @@ end
9297

9398
---Fetch model list and model info.
9499
---Aborts if there's another fetch job running.
95-
---Returns the number of models if the fetches are fired.
100+
---@return boolean whether the cache was successfully updated
96101
---@param adapter CodeCompanion.HTTPAdapter Mistral adapter with env var replaced.
97102
local function fetch_async(adapter)
98103
assert(adapter ~= nil)
104+
105+
utils.get_env_vars(adapter)
99106
if running then
100-
return
107+
return false
101108
end
102109

103110
running = true
104111

105-
_cached_models = _cached_models or {} -- TODO: this has the side effect models are never removed
112+
_cached_models = _cached_models or {}
106113

107114
local models_endpoint = "/v1/models"
108115
local headers = {
109116
["content-type"] = "application/json",
110117
["Authorization"] = "Bearer " .. adapter.env_replaced.api_key,
111118
}
112119
local url = adapter.env_replaced.url
113-
pcall(function()
120+
local ok, err = pcall(function()
114121
Curl.get(url .. models_endpoint, {
115122
headers = headers,
116123
insecure = config.adapters.http.opts.allow_insecure,
@@ -120,18 +127,19 @@ local function fetch_async(adapter)
120127
-- This can happen when you update the vim UI in a curl callback.
121128
callback = vim.schedule_wrap(function(response)
122129
if response.status ~= 200 then
123-
running = false
124-
return log:error(
130+
log:error(
125131
"Could not get Mistral models from " .. url .. models_endpoint .. ". Error: %s",
126132
response.body
127133
)
134+
running = false
135+
return false
128136
end
129137

130138
local ok, json = pcall(vim.json.decode, response.body)
131139
if not ok then
132-
running = false
133140
log:error("Could not parse the response from " .. url .. models_endpoint)
134-
return {}
141+
running = false
142+
return false
135143
end
136144

137145
for _, model_obj in ipairs(dedup_models(json.data)) do
@@ -150,39 +158,29 @@ local function fetch_async(adapter)
150158
running = false
151159
end),
152160
})
153-
if adapter.opts.cache_adapter == false then
154-
vim.wait(CONSTANTS.TIMEOUT, function()
155-
local models = _cached_models
156-
return models ~= nil and not vim.tbl_isempty(models) and not running
157-
end)
158-
end
159161
end)
162+
163+
if not ok then
164+
log:error("Could not fetch Mistral models: %s", err)
165+
running = false
166+
return false
167+
end
168+
return true
160169
end
161170

162171
---@param self CodeCompanion.HTTPAdapter
163-
---@return MistralModelCache
164-
function M.choices(self, opts)
165-
local adapter = require("codecompanion.adapters.http").resolve(self)
172+
---@return table<string, MistralModelInfo>
173+
function M.choices(self)
166174

167-
if not adapter then
168-
log:error("Could not resolve Mistral adapter in the `choices` function")
169-
return {}
175+
local models = get_cached_models()
176+
if models ~= nil and next(models) then
177+
return models
170178
end
171-
opts = opts or { async = true }
172-
173-
utils.get_env_vars(adapter)
174-
local is_uninitialised = _cached_models == nil or next(_cached_models)
175-
176-
local should_block = (adapter.opts.cache_adapter == false) or is_uninitialised or not opts.async
177179

178-
fetch_async(adapter)
179-
180-
if should_block and running then
181-
vim.wait(CONSTANTS.TIMEOUT, function()
182-
return not running
183-
end)
184-
end
180+
fetch_async(self)
181+
vim.wait(CONSTANTS.TIMEOUT, function()
182+
return not running
183+
end)
185184
return get_cached_models()
186185
end
187-
188186
return M

lua/codecompanion/adapters/http/mistral/init.lua

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ local openai = require("codecompanion.adapters.http.openai")
33

44
---Resolves the options that a model has
55
---@param adapter CodeCompanion.HTTPAdapter
6-
---@return table
6+
---@return MistralModelInfo | nil
77
local function resolve_model_opts(adapter)
88
local model = adapter.schema.model.default
99
local choices = adapter.schema.model.choices
@@ -28,7 +28,6 @@ return {
2828
stream = true,
2929
tools = true,
3030
vision = true,
31-
cache_adapter = true, -- Cache the resolved adapter to prevent multiple resolutions
3231
},
3332
features = {
3433
text = true,
@@ -44,6 +43,9 @@ return {
4443
["Content-Type"] = "application/json",
4544
},
4645
handlers = {
46+
47+
---@param self CodeCompanion.HTTPAdapter
48+
---@return boolean
4749
setup = function(self)
4850
local model_opts = resolve_model_opts(self)
4951

@@ -104,9 +106,9 @@ return {
104106
type = "enum",
105107
desc = "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
106108
default = "mistral-small-latest",
107-
---@type fun(self: CodeCompanion.HTTPAdapter, opts?: table): table
108-
choices = function(self, opts)
109-
return get_models.choices(self, opts)
109+
---@type fun(self: CodeCompanion.HTTPAdapter): table<string, MistralModelInfo>
110+
choices = function(self)
111+
return get_models.choices(self)
110112
end,
111113
},
112114
temperature = {

tests/adapters/http/mistral/test_models.lua

Lines changed: 1 addition & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
local h = require("tests.helpers")
33

44
local new_set = MiniTest.new_set
5-
local expect = MiniTest.expect
6-
local eq = MiniTest.expect.equality
75

86
local child = MiniTest.new_child_neovim()
97

@@ -56,7 +54,7 @@ T["mistral.models"]["choices() synchronous returns expected models"] = function(
5654
adapters_utils.get_env_vars = function(adapter) end
5755
5856
local adapter = { opts = {} }
59-
return get_models.choices(adapter, { async = false })
57+
return get_models.choices(adapter)
6058
]])
6159

6260
-- This expected output is based on the logic in get_models.lua:
@@ -93,79 +91,4 @@ T["mistral.models"]["choices() synchronous returns expected models"] = function(
9391
h.eq(result, expected)
9492
end
9593

96-
T["mistral.models"]["choices() async populates cache and returns later"] = function()
97-
local first, second = unpack(child.lua([[
98-
local get_models = require("codecompanion.adapters.http.mistral.get_models")
99-
100-
-- Mock Curl.get to return stub data
101-
local curl = require("plenary.curl")
102-
local body = vim.fn.readfile("tests/adapters/http/stubs/mistral_models.json")
103-
body = table.concat(body, "\n")
104-
105-
curl.get = function(url, opts)
106-
if opts and type(opts.callback) == "function" then
107-
opts.callback({ status = 200, body = body })
108-
end
109-
return { status = 200, body = body }
110-
end
111-
112-
-- Mock resolve() to return test adapter
113-
local http_adapters = require("codecompanion.adapters.http")
114-
http_adapters.resolve = function(self)
115-
return {
116-
env_replaced = {
117-
url = "https://api.mistral.ai",
118-
api_key = "test-key",
119-
},
120-
opts = {},
121-
}
122-
end
123-
124-
-- Mock get_env_vars() to do nothing
125-
local adapters_utils = require("codecompanion.utils.adapters")
126-
adapters_utils.get_env_vars = function(adapter) end
127-
128-
local adapter = { opts = {} }
129-
130-
-- Start async fetch: should return nil initially (no cache yet)
131-
local first = get_models.choices(adapter, { async = true })
132-
133-
-- Give scheduled callback a chance to run and fill cache
134-
vim.wait(50, function() return false end)
135-
136-
-- Second call should return cached models
137-
local second = get_models.choices(adapter, { async = true })
138-
139-
return { first, second }
140-
]]))
141-
142-
local expected = {
143-
["mistral-medium-2505"] = {
144-
formatted_name = "mistral-medium-2505",
145-
opts = {
146-
has_vision = true,
147-
can_use_tools = true,
148-
},
149-
},
150-
["mistral-large-latest"] = {
151-
formatted_name = "mistral-large-latest",
152-
opts = {
153-
has_vision = true,
154-
can_use_tools = true,
155-
},
156-
},
157-
["ministral-8b-latest"] = {
158-
formatted_name = "ministral-8b-2410",
159-
opts = {
160-
has_vision = false,
161-
can_use_tools = true,
162-
},
163-
},
164-
}
165-
166-
-- Mistral blocks on first call when cache is uninitialized, even with async=true
167-
h.eq(expected, first)
168-
h.eq(expected, second)
169-
end
170-
17194
return T

0 commit comments

Comments
 (0)