def trace_oci_genai(
    client: "GenerativeAiInferenceClient",
+    estimate_tokens: bool = True,
) -> "GenerativeAiInferenceClient":
    """Patch the OCI Generative AI client to trace chat completions.
@@ -47,6 +48,9 @@ def trace_oci_genai(
    ----------
    client : GenerativeAiInferenceClient
        The OCI Generative AI client to patch.
+    estimate_tokens : bool, optional
+        Whether to estimate token counts when not provided by the OCI response.
+        Defaults to True. When False, token fields will be None if not available.

    Returns
    -------
@@ -84,6 +88,7 @@ def traced_chat_func(*args, **kwargs):
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
+                estimate_tokens=estimate_tokens,
            )
        else:
            return handle_non_streaming_chat(
@@ -92,6 +97,7 @@ def traced_chat_func(*args, **kwargs):
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
+                estimate_tokens=estimate_tokens,
            )

    client.chat = traced_chat_func
@@ -104,6 +110,7 @@ def handle_streaming_chat(
    kwargs: Dict[str, Any],
    start_time: float,
    end_time: float,
+    estimate_tokens: bool = True,
) -> Iterator[Any]:
    """Handles the chat method when streaming is enabled.

@@ -127,6 +134,7 @@ def handle_streaming_chat(
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
+        estimate_tokens=estimate_tokens,
    )

@@ -136,6 +144,7 @@ def stream_chunks(
    kwargs: Dict[str, Any],
    start_time: float,
    end_time: float,
+    estimate_tokens: bool = True,
):
    """Streams the chunks of the completion and traces the completion."""
    collected_output_data = []
@@ -164,15 +173,18 @@ def stream_chunks(
                    usage = chunk.data.usage
                    num_of_prompt_tokens = getattr(usage, "prompt_tokens", 0)
                else:
-                    # OCI doesn't provide usage info, estimate from chat_details
-                    num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
+                    # OCI doesn't provide usage info, estimate from chat_details if enabled
+                    if estimate_tokens:
+                        num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
+                    else:
+                        num_of_prompt_tokens = None

                # Store first chunk sample (only for debugging)
                if hasattr(chunk, "data"):
                    chunk_samples.append({"index": 0, "type": "first"})

-            # Update completion tokens count
-            if i > 0:
+            # Update completion tokens count (estimation based)
+            if i > 0 and estimate_tokens:
                num_of_completion_tokens = i + 1

            # Fast content extraction - optimized for performance
@@ -208,8 +220,11 @@ def stream_chunks(
        # chat_details is passed directly as parameter
        model_id = extract_model_id(chat_details)

-        # Calculate total tokens
-        total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
+        # Calculate total tokens - handle None values properly
+        if estimate_tokens:
+            total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
+        else:
+            total_tokens = None if num_of_prompt_tokens is None and num_of_completion_tokens is None else ((num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0))

        # Simplified metadata - only essential timing info
        metadata = {
@@ -222,8 +237,8 @@
            output=output_data,
            latency=latency,
            tokens=total_tokens,
-            prompt_tokens=num_of_prompt_tokens or 0,
-            completion_tokens=num_of_completion_tokens or 0,
+            prompt_tokens=num_of_prompt_tokens,
+            completion_tokens=num_of_completion_tokens,
            model=model_id,
            model_parameters=get_model_parameters(chat_details),
            raw_output={
@@ -251,6 +266,7 @@ def handle_non_streaming_chat(
    kwargs: Dict[str, Any],
    start_time: float,
    end_time: float,
+    estimate_tokens: bool = True,
) -> Any:
    """Handles the chat method when streaming is disabled.

@@ -274,7 +290,7 @@ def handle_non_streaming_chat(
    try:
        # Parse response and extract data
        output_data = parse_non_streaming_output_data(response)
-        tokens_info = extract_tokens_info(response, chat_details)
+        tokens_info = extract_tokens_info(response, chat_details, estimate_tokens)
        model_id = extract_model_id(chat_details)

        latency = (end_time - start_time) * 1000
@@ -287,9 +303,9 @@ def handle_non_streaming_chat(
            inputs=extract_inputs_from_chat_details(chat_details),
            output=output_data,
            latency=latency,
-            tokens=tokens_info.get("total_tokens", 0),
-            prompt_tokens=tokens_info.get("input_tokens", 0),
-            completion_tokens=tokens_info.get("output_tokens", 0),
+            tokens=tokens_info.get("total_tokens"),
+            prompt_tokens=tokens_info.get("input_tokens"),
+            completion_tokens=tokens_info.get("output_tokens"),
            model=model_id,
            model_parameters=get_model_parameters(chat_details),
            raw_output=response.data.__dict__ if hasattr(response, "data") else response.__dict__,
@@ -472,10 +488,10 @@ def parse_non_streaming_output_data(response) -> Union[str, Dict[str, Any], None
    return str(data)


-def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
+def estimate_prompt_tokens_from_chat_details(chat_details) -> Optional[int]:
    """Estimate prompt tokens from chat details when OCI doesn't provide usage info."""
    if not chat_details:
-        return 10  # Fallback estimate
+        return None

    try:
        input_text = ""
@@ -491,72 +507,107 @@ def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
        return estimated_tokens
    except Exception as e:
        logger.debug("Error estimating prompt tokens: %s", e)
-        return 10  # Fallback estimate
+        return None


-def extract_tokens_info(response, chat_details=None) -> Dict[str, int]:
-    """Extract token usage information from the response."""
-    tokens_info = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+def extract_tokens_info(response, chat_details=None, estimate_tokens: bool = True) -> Dict[str, Optional[int]]:
+    """Extract token usage information from the response.
+
+    Handles both CohereChatResponse and GenericChatResponse types from OCI.
+
+    Parameters
+    ----------
+    response : Any
+        The OCI chat response object (CohereChatResponse or GenericChatResponse)
+    chat_details : Any, optional
+        The chat details for token estimation if needed
+    estimate_tokens : bool, optional
+        Whether to estimate tokens when not available in response. Defaults to True.
+
+    Returns
+    -------
+    Dict[str, Optional[int]]
+        Dictionary with token counts. Values can be None if unavailable and estimation disabled.
+    """
+    tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}

    try:
-        # First, try the standard locations for token usage
+        # Extract token usage from OCI response (handles both CohereChatResponse and GenericChatResponse)
        if hasattr(response, "data"):
-            # Check multiple possible locations for usage info
-            usage_locations = [
-                getattr(response.data, "usage", None),
-                getattr(getattr(response.data, "chat_response", None), "usage", None),
-            ]
-
-            for usage in usage_locations:
-                if usage is not None:
-                    tokens_info["input_tokens"] = getattr(usage, "prompt_tokens", 0)
-                    tokens_info["output_tokens"] = getattr(usage, "completion_tokens", 0)
-                    tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
-                    logger.debug("Found token usage info: %s", tokens_info)
-                    return tokens_info
-
-        # If no usage info found, estimate based on text length
-        # This is common for OCI which doesn't return token counts
-        logger.debug("No token usage found in response, estimating from text length")
+            usage = None
+
+            # For CohereChatResponse: response.data.usage
+            if hasattr(response.data, "usage"):
+                usage = response.data.usage
+            # For GenericChatResponse: response.data.chat_response.usage
+            elif hasattr(response.data, "chat_response") and hasattr(response.data.chat_response, "usage"):
+                usage = response.data.chat_response.usage
+
+            if usage is not None:
+                # Extract tokens from usage object
+                prompt_tokens = getattr(usage, "prompt_tokens", None)
+                completion_tokens = getattr(usage, "completion_tokens", None)
+                total_tokens = getattr(usage, "total_tokens", None)
+
+                tokens_info["input_tokens"] = prompt_tokens
+                tokens_info["output_tokens"] = completion_tokens
+                tokens_info["total_tokens"] = total_tokens or (
+                    (prompt_tokens + completion_tokens) if prompt_tokens is not None and completion_tokens is not None else None
+                )
+                logger.debug("Found token usage info: %s", tokens_info)
+                return tokens_info

-        # Estimate input tokens from chat_details
-        if chat_details:
+        # If no usage info found, estimate based on text length only if estimation is enabled
+        if estimate_tokens:
+            logger.debug("No token usage found in response, estimating from text length")
+
+            # Estimate input tokens from chat_details
+            if chat_details:
+                try:
+                    input_text = ""
+                    if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
+                        for msg in chat_details.chat_request.messages:
+                            if hasattr(msg, "content") and msg.content:
+                                for content_item in msg.content:
+                                    if hasattr(content_item, "text"):
+                                        input_text += content_item.text + " "
+
+                    # Rough estimation: ~4 characters per token
+                    estimated_input_tokens = max(1, len(input_text) // 4)
+                    tokens_info["input_tokens"] = estimated_input_tokens
+                except Exception as e:
+                    logger.debug("Error estimating input tokens: %s", e)
+                    tokens_info["input_tokens"] = None
+
+            # Estimate output tokens from response
            try:
-                input_text = ""
-                if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
-                    for msg in chat_details.chat_request.messages:
-                        if hasattr(msg, "content") and msg.content:
-                            for content_item in msg.content:
-                                if hasattr(content_item, "text"):
-                                    input_text += content_item.text + " "
-
-                # Rough estimation: ~4 characters per token
-                estimated_input_tokens = max(1, len(input_text) // 4)
-                tokens_info["input_tokens"] = estimated_input_tokens
+                output_text = parse_non_streaming_output_data(response)
+                if isinstance(output_text, str):
+                    # Rough estimation: ~4 characters per token
+                    estimated_output_tokens = max(1, len(output_text) // 4)
+                    tokens_info["output_tokens"] = estimated_output_tokens
+                else:
+                    tokens_info["output_tokens"] = None
            except Exception as e:
-                logger.debug("Error estimating input tokens: %s", e)
-                tokens_info["input_tokens"] = 10  # Fallback estimate
+                logger.debug("Error estimating output tokens: %s", e)
+                tokens_info["output_tokens"] = None

-        # Estimate output tokens from response
-        try:
-            output_text = parse_non_streaming_output_data(response)
-            if isinstance(output_text, str):
-                # Rough estimation: ~4 characters per token
-                estimated_output_tokens = max(1, len(output_text) // 4)
-                tokens_info["output_tokens"] = estimated_output_tokens
+            # Calculate total tokens only if we have estimates
+            if tokens_info["input_tokens"] is not None and tokens_info["output_tokens"] is not None:
+                tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
+            elif tokens_info["input_tokens"] is not None or tokens_info["output_tokens"] is not None:
+                tokens_info["total_tokens"] = (tokens_info["input_tokens"] or 0) + (tokens_info["output_tokens"] or 0)
            else:
-                tokens_info["output_tokens"] = 5  # Fallback estimate
-        except Exception as e:
-            logger.debug("Error estimating output tokens: %s", e)
-            tokens_info["output_tokens"] = 5  # Fallback estimate
-
-        tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
-        logger.debug("Estimated token usage: %s", tokens_info)
+                tokens_info["total_tokens"] = None
+
+            logger.debug("Estimated token usage: %s", tokens_info)
+        else:
+            logger.debug("No token usage found in response and estimation disabled, returning None values")

    except Exception as e:
        logger.debug("Error extracting/estimating token info: %s", e)
-        # Provide minimal fallback estimates
-        tokens_info = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
+        # Always return None values on exceptions (no more fallback values)
+        tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}

    return tokens_info
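# --- Usage sketch (illustration, not part of the diff) ----------------------
# A minimal example of the new flag. Assumptions: the OCI Python SDK is
# configured via a standard ~/.oci/config profile, and the tracer import path
# below is hypothetical -- substitute the module this patch modifies.
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient

from my_tracing_module import trace_oci_genai  # hypothetical import path

config = oci.config.from_file()
client = GenerativeAiInferenceClient(config=config)

# With estimate_tokens=False, traced prompt/completion token fields stay None
# whenever OCI omits usage info, instead of falling back to rough estimates.
client = trace_oci_genai(client, estimate_tokens=False)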
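# --- Token-estimation heuristic, standalone sketch (not part of the diff) ---
# Illustrates the ~4-characters-per-token approximation and the None semantics
# the patch introduces. rough_token_estimate and tokens_or_none are simplified
# stand-ins for the fallback path of extract_tokens_info, not module helpers.
from typing import Dict, Optional


def rough_token_estimate(text: str) -> int:
    # Same heuristic as the patch: roughly 4 characters per token, at least 1.
    return max(1, len(text) // 4)


def tokens_or_none(
    prompt_text: Optional[str],
    output_text: Optional[str],
    estimate_tokens: bool = True,
) -> Dict[str, Optional[int]]:
    # With estimation disabled, every field stays None instead of a fake count.
    if not estimate_tokens:
        return {"input_tokens": None, "output_tokens": None, "total_tokens": None}
    input_tokens = rough_token_estimate(prompt_text) if prompt_text else None
    output_tokens = rough_token_estimate(output_text) if output_text else None
    total_tokens = None
    if input_tokens is not None or output_tokens is not None:
        total_tokens = (input_tokens or 0) + (output_tokens or 0)
    return {"input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens}


print(tokens_or_none("What is OCI Generative AI?", "It is Oracle's managed LLM service."))
# {'input_tokens': 6, 'output_tokens': 8, 'total_tokens': 14}
print(tokens_or_none("What is OCI Generative AI?", None, estimate_tokens=False))
# {'input_tokens': None, 'output_tokens': None, 'total_tokens': None}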