token counting fixes
This commit is contained in:
@@ -291,11 +291,12 @@ impl DatabricksProvider {
|
||||
&self,
|
||||
mut stream: impl futures_util::Stream<Item = reqwest::Result<Bytes>> + Unpin,
|
||||
tx: mpsc::Sender<Result<CompletionChunk>>,
|
||||
) {
|
||||
) -> Option<Usage> {
|
||||
let mut buffer = String::new();
|
||||
let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
|
||||
std::collections::HashMap::new(); // index -> (id, name, args)
|
||||
let mut incomplete_data_line = String::new(); // Buffer for incomplete data: lines
|
||||
let accumulated_usage: Option<Usage> = None;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
@@ -318,7 +319,7 @@ impl DatabricksProvider {
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Invalid UTF-8 in stream chunk: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -377,6 +378,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -386,7 +388,7 @@ impl DatabricksProvider {
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
|
||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||
@@ -410,11 +412,12 @@ impl DatabricksProvider {
|
||||
let chunk = CompletionChunk {
|
||||
content,
|
||||
finished: false,
|
||||
usage: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
if tx.send(Ok(chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,6 +505,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -511,7 +515,7 @@ impl DatabricksProvider {
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
}
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -553,7 +557,7 @@ impl DatabricksProvider {
|
||||
Err(e) => {
|
||||
error!("Stream error: {}", e);
|
||||
let _ = tx.send(Err(anyhow!("Stream error: {}", e))).await;
|
||||
return;
|
||||
return accumulated_usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -592,6 +596,7 @@ impl DatabricksProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -599,6 +604,7 @@ impl DatabricksProvider {
|
||||
},
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
accumulated_usage
|
||||
}
|
||||
|
||||
pub async fn fetch_supported_models(&mut self) -> Result<Option<Vec<String>>> {
|
||||
|
||||
Reference in New Issue
Block a user