Skip to content

Commit 488ba7c

Browse files
viniciusdsmellogustavocidornelas
authored andcommitted
feat(tracing): enhance OCI tracing functionality with token estimation options
- Updated the `trace_oci_genai` function to include an optional `estimate_tokens` parameter, allowing users to control token estimation behavior when not provided by OCI responses. - Enhanced the `oci_genai_tracing.ipynb` notebook to document the new parameter and its implications for token estimation, improving user understanding and experience. - Modified the `extract_tokens_info` function to handle token estimation more robustly, returning None for token fields when estimation is disabled. - Ensured all changes comply with coding standards, including comprehensive type annotations and Google-style docstrings for maintainability.
1 parent 2e02aa2 commit 488ba7c

File tree

2 files changed

+132
-69
lines changed

2 files changed

+132
-69
lines changed

examples/tracing/oci/oci_genai_tracing.ipynb

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,13 @@
113113
"source": [
114114
"## Apply Openlayer Tracing\n",
115115
"\n",
116-
"Wrap the OCI client with Openlayer tracing to automatically capture all interactions.\n"
116+
"Wrap the OCI client with Openlayer tracing to automatically capture all interactions.\n",
117+
"\n",
118+
"The `trace_oci_genai()` function accepts an optional `estimate_tokens` parameter:\n",
119+
"- `estimate_tokens=True` (default): Estimates token counts when not provided by OCI response\n",
120+
"- `estimate_tokens=False`: Returns None for token fields when not available in the response\n",
121+
"\n",
122+
"OCI responses can be either CohereChatResponse or GenericChatResponse, both containing usage information when available.\n"
117123
]
118124
},
119125
{
@@ -123,7 +129,13 @@
123129
"outputs": [],
124130
"source": [
125131
"# Apply Openlayer tracing to the OCI client\n",
126-
"traced_client = trace_oci_genai(client)"
132+
"# With token estimation enabled (default)\n",
133+
"traced_client = trace_oci_genai(client, estimate_tokens=True)\n",
134+
"\n",
135+
"# Alternative: Disable token estimation to get None values when tokens are not available\n",
136+
"# traced_client = trace_oci_genai(client, estimate_tokens=False)\n",
137+
"\n",
138+
"print(\"Openlayer OCI tracer applied successfully!\")"
127139
]
128140
},
129141
{

src/openlayer/lib/integrations/oci_tracer.py

Lines changed: 118 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
def trace_oci_genai(
2828
client: "GenerativeAiInferenceClient",
29+
estimate_tokens: bool = True,
2930
) -> "GenerativeAiInferenceClient":
3031
"""Patch the OCI Generative AI client to trace chat completions.
3132
@@ -47,6 +48,9 @@ def trace_oci_genai(
4748
----------
4849
client : GenerativeAiInferenceClient
4950
The OCI Generative AI client to patch.
51+
estimate_tokens : bool, optional
52+
Whether to estimate token counts when not provided by the OCI response.
53+
Defaults to True. When False, token fields will be None if not available.
5054
5155
Returns
5256
-------
@@ -84,6 +88,7 @@ def traced_chat_func(*args, **kwargs):
8488
kwargs=kwargs,
8589
start_time=start_time,
8690
end_time=end_time,
91+
estimate_tokens=estimate_tokens,
8792
)
8893
else:
8994
return handle_non_streaming_chat(
@@ -92,6 +97,7 @@ def traced_chat_func(*args, **kwargs):
9297
kwargs=kwargs,
9398
start_time=start_time,
9499
end_time=end_time,
100+
estimate_tokens=estimate_tokens,
95101
)
96102

97103
client.chat = traced_chat_func
@@ -104,6 +110,7 @@ def handle_streaming_chat(
104110
kwargs: Dict[str, Any],
105111
start_time: float,
106112
end_time: float,
113+
estimate_tokens: bool = True,
107114
) -> Iterator[Any]:
108115
"""Handles the chat method when streaming is enabled.
109116
@@ -127,6 +134,7 @@ def handle_streaming_chat(
127134
kwargs=kwargs,
128135
start_time=start_time,
129136
end_time=end_time,
137+
estimate_tokens=estimate_tokens,
130138
)
131139

132140

@@ -136,6 +144,7 @@ def stream_chunks(
136144
kwargs: Dict[str, Any],
137145
start_time: float,
138146
end_time: float,
147+
estimate_tokens: bool = True,
139148
):
140149
"""Streams the chunks of the completion and traces the completion."""
141150
collected_output_data = []
@@ -164,15 +173,18 @@ def stream_chunks(
164173
usage = chunk.data.usage
165174
num_of_prompt_tokens = getattr(usage, "prompt_tokens", 0)
166175
else:
167-
# OCI doesn't provide usage info, estimate from chat_details
168-
num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
176+
# OCI doesn't provide usage info, estimate from chat_details if enabled
177+
if estimate_tokens:
178+
num_of_prompt_tokens = estimate_prompt_tokens_from_chat_details(chat_details)
179+
else:
180+
num_of_prompt_tokens = None
169181

170182
# Store first chunk sample (only for debugging)
171183
if hasattr(chunk, "data"):
172184
chunk_samples.append({"index": 0, "type": "first"})
173185

174-
# Update completion tokens count
175-
if i > 0:
186+
# Update completion tokens count (estimation based)
187+
if i > 0 and estimate_tokens:
176188
num_of_completion_tokens = i + 1
177189

178190
# Fast content extraction - optimized for performance
@@ -208,8 +220,11 @@ def stream_chunks(
208220
# chat_details is passed directly as parameter
209221
model_id = extract_model_id(chat_details)
210222

211-
# Calculate total tokens
212-
total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
223+
# Calculate total tokens - handle None values properly
224+
if estimate_tokens:
225+
total_tokens = (num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0)
226+
else:
227+
total_tokens = None if num_of_prompt_tokens is None and num_of_completion_tokens is None else ((num_of_prompt_tokens or 0) + (num_of_completion_tokens or 0))
213228

214229
# Simplified metadata - only essential timing info
215230
metadata = {
@@ -222,8 +237,8 @@ def stream_chunks(
222237
output=output_data,
223238
latency=latency,
224239
tokens=total_tokens,
225-
prompt_tokens=num_of_prompt_tokens or 0,
226-
completion_tokens=num_of_completion_tokens or 0,
240+
prompt_tokens=num_of_prompt_tokens,
241+
completion_tokens=num_of_completion_tokens,
227242
model=model_id,
228243
model_parameters=get_model_parameters(chat_details),
229244
raw_output={
@@ -251,6 +266,7 @@ def handle_non_streaming_chat(
251266
kwargs: Dict[str, Any],
252267
start_time: float,
253268
end_time: float,
269+
estimate_tokens: bool = True,
254270
) -> Any:
255271
"""Handles the chat method when streaming is disabled.
256272
@@ -274,7 +290,7 @@ def handle_non_streaming_chat(
274290
try:
275291
# Parse response and extract data
276292
output_data = parse_non_streaming_output_data(response)
277-
tokens_info = extract_tokens_info(response, chat_details)
293+
tokens_info = extract_tokens_info(response, chat_details, estimate_tokens)
278294
model_id = extract_model_id(chat_details)
279295

280296
latency = (end_time - start_time) * 1000
@@ -287,9 +303,9 @@ def handle_non_streaming_chat(
287303
inputs=extract_inputs_from_chat_details(chat_details),
288304
output=output_data,
289305
latency=latency,
290-
tokens=tokens_info.get("total_tokens", 0),
291-
prompt_tokens=tokens_info.get("input_tokens", 0),
292-
completion_tokens=tokens_info.get("output_tokens", 0),
306+
tokens=tokens_info.get("total_tokens"),
307+
prompt_tokens=tokens_info.get("input_tokens"),
308+
completion_tokens=tokens_info.get("output_tokens"),
293309
model=model_id,
294310
model_parameters=get_model_parameters(chat_details),
295311
raw_output=response.data.__dict__ if hasattr(response, "data") else response.__dict__,
@@ -472,10 +488,10 @@ def parse_non_streaming_output_data(response) -> Union[str, Dict[str, Any], None
472488
return str(data)
473489

474490

475-
def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
491+
def estimate_prompt_tokens_from_chat_details(chat_details) -> Optional[int]:
476492
"""Estimate prompt tokens from chat details when OCI doesn't provide usage info."""
477493
if not chat_details:
478-
return 10 # Fallback estimate
494+
return None
479495

480496
try:
481497
input_text = ""
@@ -491,72 +507,107 @@ def estimate_prompt_tokens_from_chat_details(chat_details) -> int:
491507
return estimated_tokens
492508
except Exception as e:
493509
logger.debug("Error estimating prompt tokens: %s", e)
494-
return 10 # Fallback estimate
510+
return None
495511

496512

497-
def extract_tokens_info(response, chat_details=None) -> Dict[str, int]:
498-
"""Extract token usage information from the response."""
499-
tokens_info = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
513+
def extract_tokens_info(response, chat_details=None, estimate_tokens: bool = True) -> Dict[str, Optional[int]]:
514+
"""Extract token usage information from the response.
515+
516+
Handles both CohereChatResponse and GenericChatResponse types from OCI.
517+
518+
Parameters
519+
----------
520+
response : Any
521+
The OCI chat response object (CohereChatResponse or GenericChatResponse)
522+
chat_details : Any, optional
523+
The chat details for token estimation if needed
524+
estimate_tokens : bool, optional
525+
Whether to estimate tokens when not available in response. Defaults to True.
526+
527+
Returns
528+
-------
529+
Dict[str, Optional[int]]
530+
Dictionary with token counts. Values can be None if unavailable and estimation disabled.
531+
"""
532+
tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}
500533

501534
try:
502-
# First, try the standard locations for token usage
535+
# Extract token usage from OCI response (handles both CohereChatResponse and GenericChatResponse)
503536
if hasattr(response, "data"):
504-
# Check multiple possible locations for usage info
505-
usage_locations = [
506-
getattr(response.data, "usage", None),
507-
getattr(getattr(response.data, "chat_response", None), "usage", None),
508-
]
509-
510-
for usage in usage_locations:
511-
if usage is not None:
512-
tokens_info["input_tokens"] = getattr(usage, "prompt_tokens", 0)
513-
tokens_info["output_tokens"] = getattr(usage, "completion_tokens", 0)
514-
tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
515-
logger.debug("Found token usage info: %s", tokens_info)
516-
return tokens_info
517-
518-
# If no usage info found, estimate based on text length
519-
# This is common for OCI which doesn't return token counts
520-
logger.debug("No token usage found in response, estimating from text length")
537+
usage = None
538+
539+
# For CohereChatResponse: response.data.usage
540+
if hasattr(response.data, "usage"):
541+
usage = response.data.usage
542+
# For GenericChatResponse: response.data.chat_response.usage
543+
elif hasattr(response.data, "chat_response") and hasattr(response.data.chat_response, "usage"):
544+
usage = response.data.chat_response.usage
545+
546+
if usage is not None:
547+
# Extract tokens from usage object
548+
prompt_tokens = getattr(usage, "prompt_tokens", None)
549+
completion_tokens = getattr(usage, "completion_tokens", None)
550+
total_tokens = getattr(usage, "total_tokens", None)
551+
552+
tokens_info["input_tokens"] = prompt_tokens
553+
tokens_info["output_tokens"] = completion_tokens
554+
tokens_info["total_tokens"] = total_tokens or (
555+
(prompt_tokens + completion_tokens) if prompt_tokens is not None and completion_tokens is not None else None
556+
)
557+
logger.debug("Found token usage info: %s", tokens_info)
558+
return tokens_info
521559

522-
# Estimate input tokens from chat_details
523-
if chat_details:
560+
# If no usage info found, estimate based on text length only if estimation is enabled
561+
if estimate_tokens:
562+
logger.debug("No token usage found in response, estimating from text length")
563+
564+
# Estimate input tokens from chat_details
565+
if chat_details:
566+
try:
567+
input_text = ""
568+
if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
569+
for msg in chat_details.chat_request.messages:
570+
if hasattr(msg, "content") and msg.content:
571+
for content_item in msg.content:
572+
if hasattr(content_item, "text"):
573+
input_text += content_item.text + " "
574+
575+
# Rough estimation: ~4 characters per token
576+
estimated_input_tokens = max(1, len(input_text) // 4)
577+
tokens_info["input_tokens"] = estimated_input_tokens
578+
except Exception as e:
579+
logger.debug("Error estimating input tokens: %s", e)
580+
tokens_info["input_tokens"] = None
581+
582+
# Estimate output tokens from response
524583
try:
525-
input_text = ""
526-
if hasattr(chat_details, "chat_request") and hasattr(chat_details.chat_request, "messages"):
527-
for msg in chat_details.chat_request.messages:
528-
if hasattr(msg, "content") and msg.content:
529-
for content_item in msg.content:
530-
if hasattr(content_item, "text"):
531-
input_text += content_item.text + " "
532-
533-
# Rough estimation: ~4 characters per token
534-
estimated_input_tokens = max(1, len(input_text) // 4)
535-
tokens_info["input_tokens"] = estimated_input_tokens
584+
output_text = parse_non_streaming_output_data(response)
585+
if isinstance(output_text, str):
586+
# Rough estimation: ~4 characters per token
587+
estimated_output_tokens = max(1, len(output_text) // 4)
588+
tokens_info["output_tokens"] = estimated_output_tokens
589+
else:
590+
tokens_info["output_tokens"] = None
536591
except Exception as e:
537-
logger.debug("Error estimating input tokens: %s", e)
538-
tokens_info["input_tokens"] = 10 # Fallback estimate
592+
logger.debug("Error estimating output tokens: %s", e)
593+
tokens_info["output_tokens"] = None
539594

540-
# Estimate output tokens from response
541-
try:
542-
output_text = parse_non_streaming_output_data(response)
543-
if isinstance(output_text, str):
544-
# Rough estimation: ~4 characters per token
545-
estimated_output_tokens = max(1, len(output_text) // 4)
546-
tokens_info["output_tokens"] = estimated_output_tokens
595+
# Calculate total tokens only if we have estimates
596+
if tokens_info["input_tokens"] is not None and tokens_info["output_tokens"] is not None:
597+
tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
598+
elif tokens_info["input_tokens"] is not None or tokens_info["output_tokens"] is not None:
599+
tokens_info["total_tokens"] = (tokens_info["input_tokens"] or 0) + (tokens_info["output_tokens"] or 0)
547600
else:
548-
tokens_info["output_tokens"] = 5 # Fallback estimate
549-
except Exception as e:
550-
logger.debug("Error estimating output tokens: %s", e)
551-
tokens_info["output_tokens"] = 5 # Fallback estimate
552-
553-
tokens_info["total_tokens"] = tokens_info["input_tokens"] + tokens_info["output_tokens"]
554-
logger.debug("Estimated token usage: %s", tokens_info)
601+
tokens_info["total_tokens"] = None
602+
603+
logger.debug("Estimated token usage: %s", tokens_info)
604+
else:
605+
logger.debug("No token usage found in response and estimation disabled, returning None values")
555606

556607
except Exception as e:
557608
logger.debug("Error extracting/estimating token info: %s", e)
558-
# Provide minimal fallback estimates
559-
tokens_info = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
609+
# Always return None values on exceptions (no more fallback values)
610+
tokens_info = {"input_tokens": None, "output_tokens": None, "total_tokens": None}
560611

561612
return tokens_info
562613

0 commit comments

Comments
 (0)