token counting fixes

This commit is contained in:
Dhanji Prasanna
2025-10-09 12:11:14 +11:00
parent 9d1eef82b9
commit 260c949576
7 changed files with 283 additions and 22 deletions

View File

@@ -291,11 +291,12 @@ impl DatabricksProvider {
&self,
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
) -> Option<Usage> {
let mut buffer = String::new();
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
std::collections::HashMap::new(); // index -> (id, name, args)
let mut incomplete_data_line = String::new(); // Buffer for incomplete "data:" lines
let accumulated_usage: Option<Usage> = None;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
@@ -318,7 +319,7 @@ impl DatabricksProvider {
let _ = tx
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
.await;
return;
return accumulated_usage;
}
};
@@ -377,6 +378,7 @@ impl DatabricksProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
@@ -386,7 +388,7 @@ impl DatabricksProvider {
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
return accumulated_usage;
}
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
@@ -410,11 +412,12 @@ impl DatabricksProvider {
let chunk = CompletionChunk {
content,
finished: false,
usage: None,
tool_calls: None,
};
if tx.send(Ok(chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
return;
return accumulated_usage;
}
}
@@ -502,6 +505,7 @@ impl DatabricksProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
@@ -511,7 +515,7 @@ impl DatabricksProvider {
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
}
return;
return accumulated_usage;
}
}
}
@@ -553,7 +557,7 @@ impl DatabricksProvider {
Err(e) => {
error!("Stream error: {}", e);
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
return;
return accumulated_usage;
}
}
}
@@ -592,6 +596,7 @@ impl DatabricksProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if final_tool_calls.is_empty() {
None
} else {
@@ -599,6 +604,7 @@ impl DatabricksProvider {
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {