This repository was archived by the owner on May 19, 2025. It is now read-only.

Refactor to agent #2

Open · wants to merge 9 commits into main
113 changes: 113 additions & 0 deletions agent/data_agent.py
@@ -0,0 +1,113 @@
from typing import TypedDict, List
from langchain_openai import AzureChatOpenAI
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

from tools.scraping import gen_queries, get_video_ids, download, VideoInfo
from tools.video_chunking import detect_segments, SegmentInfo
from tools.annotating import extract_clues, gen_annotations

from tools.prompts import (
GEN_QUERIES_PROMPT,
EXTRACT_CLUES_PROMPT,
GEN_ANNOTATIONS_PROMPT,
)


llm = AzureChatOpenAI(
temperature=0.0,
azure_deployment="gpt4o",
openai_api_version="2023-07-01-preview",
)

memory = MemorySaver()
# memory = SqliteSaver.from_conn_string(":memory:")


class AgentState(TypedDict):
task: str
search_queries: List[str]
video_ids: List[str]
video_infos: List[VideoInfo]
clip_text_prompts: List[str]
segment_infos: List[SegmentInfo]
clues: List[str]
annotations: List[str]


class DataAgent:
def __init__(self, llm, memory):
self.llm = llm
self.memory = memory
self.graph = self.build_graph()

    def build_graph(self):
        # Linear pipeline: queries -> video ids -> download -> segment
        # detection -> clue extraction -> annotation generation.
        builder = StateGraph(AgentState)

builder.add_node("generate_queries", self.gen_queries_node)
builder.add_node("get_video_ids", self.get_video_ids_node)
builder.add_node("download", self.download_node)
builder.add_node("detect_segments", self.detect_segments_node)
builder.add_node("extract_clues", self.extract_clues_node)
builder.add_node("gen_annotations", self.gen_annotations_node)

builder.set_entry_point("generate_queries")

builder.add_edge("generate_queries", "get_video_ids")
builder.add_edge("get_video_ids", "download")
builder.add_edge("download", "detect_segments")
builder.add_edge("detect_segments", "extract_clues")
builder.add_edge("extract_clues", "gen_annotations")
builder.add_edge("gen_annotations", END)

        graph = builder.compile(checkpointer=self.memory)

return graph

    def gen_queries_node(self, state: AgentState):
        search_queries = gen_queries(self.llm, state["task"], GEN_QUERIES_PROMPT)
        # Keep only the first two queries to limit scraping volume.
        return {"search_queries": search_queries[:2]}

def get_video_ids_node(self, state: AgentState):
video_ids = get_video_ids(state["search_queries"])
return {"video_ids": video_ids}

def download_node(self, state: AgentState):
video_infos = download(state["video_ids"])
return {"video_infos": video_infos}

def detect_segments_node(self, state: AgentState):
segment_infos = detect_segments(
state["video_infos"], state["clip_text_prompts"]
)
return {"segment_infos": segment_infos}

def extract_clues_node(self, state: AgentState):
clues = extract_clues(
self.llm,
EXTRACT_CLUES_PROMPT,
state["segment_infos"],
state["video_infos"],
)
return {"clues": clues}

def gen_annotations_node(self, state: AgentState):
annotations = gen_annotations(self.llm, GEN_ANNOTATIONS_PROMPT, state["clues"])
return {"annotations": annotations}

def run(self, task: str, thread_id: str):
thread = {"configurable": {"thread_id": thread_id}}
for step in self.graph.stream(
{
"task": task,
"clip_text_prompts": ["person doing squats"],
},
thread,
):
if "download" in step:
print("dowload happened")
elif "extract_clues" in step:
print("extract_clues happened")
else:
print(step)
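
A minimal usage sketch (illustrative, not part of this diff), assuming the DataAgent above and the in-process MemorySaver: because the graph is compiled with a checkpointer, the state of a thread can be read back after a run via langgraph's get_state.

agent = DataAgent(llm, memory)
agent.run("teach people how to do squats", thread_id="1")

# Read back the checkpointed state for the same thread id.
thread = {"configurable": {"thread_id": "1"}}
snapshot = agent.graph.get_state(thread)  # langgraph StateSnapshot
print(snapshot.values.get("search_queries"))
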
98 changes: 98 additions & 0 deletions agent/run_agent.ipynb
@@ -0,0 +1,98 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"_ = load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"torch.device(\"cuda:0\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from data_agent import DataAgent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import AzureChatOpenAI\n",
"from langgraph.checkpoint.memory import MemorySaver\n",
"\n",
"\n",
"llm = AzureChatOpenAI(\n",
" temperature=0.0,\n",
" azure_deployment=\"gpt4o\",\n",
" openai_api_version=\"2023-07-01-preview\",\n",
")\n",
"\n",
"memory = MemorySaver()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent = DataAgent(llm, memory)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.run(\"i wanna teach people how to do squats\", thread_id=\"1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "vlm_databuilder_agent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
181 changes: 181 additions & 0 deletions agent/tools/annotating.py
@@ -0,0 +1,181 @@
from typing import List, Optional
from collections import defaultdict
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate


from .scraping import VideoInfo
from .video_chunking import SegmentInfo


class LocalClue(BaseModel):
"""Local clues for a segment"""

id: str = Field(description="LC1,LC2...")
quote: str = Field(
description="the quote from the transcript that was used to create this clue."
)
quote_timestamp_start: str = Field(
description="the exact start timestamp of the quote."
)
quote_timestamp_end: str = Field(
description="the exact end timestamp of the quote."
)
clue: str = Field(description="the main clue data")


class GlobalClue(BaseModel):
"""Global clues for a segment"""

id: str = Field(description="GC1,GC2...")
quote: str = Field(
description="the quote from the transcript that was used to create this clue."
)
quote_timestamp_start: str = Field(
description="the exact start timestamp of the quote."
)
quote_timestamp_end: str = Field(
description="the exact end timestamp of the quote."
)
clue: str = Field(description="the main clue data.")
relevance_to_segment: str = Field(
description="why do you think this global clue is relevant to the segment you are working with right now."
)


class LogicalInference(BaseModel):
"""Logical inferences for a segment"""

id: str = Field(description="LI1,LI2,...")
description: str = Field(description="A concise form of the logical inference.")
details: str = Field(
description="A verbose explanation of what insight about what happens in this segment should be made based on the clues that you found."
)


class SegmentAnnotation(BaseModel):
local_clues: list[LocalClue] = Field(
description="Local clues are inside the segment in terms of timestamps."
)
global_clues: list[GlobalClue] = Field(
description="Global clues are scattered across the entire transcript."
)
    logical_inferences: list[LogicalInference] = Field(
        description="What we can infer, from the clues inside this segment, about the topic the user is looking for in the video."
    )


class SegmentWithClueInfo(BaseModel):
"""
Annotation for a video segment.
"""

start_timestamp: str = Field(
description="start timestamp of the segment in format HH:MM:SS.MS"
)
    end_timestamp: str = Field(
        description="end timestamp of the segment in format HH:MM:SS.MS"
    )
segment_annotation: SegmentAnnotation = Field(
description="list of annotations for the segment"
)


class VideoAnnotation(BaseModel):
"""
Segments of a video.
"""

segments: list[SegmentWithClueInfo] = Field(
description="information about each segment"
)


def extract_clues(
llm,
system_prompt: str,
segment_infos: List[SegmentInfo],
video_infos: List[VideoInfo],
):

prompt_template = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
(
"user",
"Segment timecodes: {{ segment_timecodes }}\nTranscript: {{ transcript }}",
),
],
template_format="jinja2",
)

model = prompt_template | llm.with_structured_output(VideoAnnotation)

segment_infos_dict = defaultdict(list)
for segment_info in segment_infos:
segment_infos_dict[segment_info.video_id].append(segment_info)

video_infos_dict = {video_info.video_id: video_info for video_info in video_infos}

clues = []

    # Rename the loop variable so it does not shadow the segment_infos parameter.
    for video_id, video_segments in segment_infos_dict.items():
        transcript = video_infos_dict[video_id].transcript
        segment_infos_chunks = [
            video_segments[i : i + 5] for i in range(0, len(video_segments), 5)
        ]

for chunk in segment_infos_chunks:
video_annotation: VideoAnnotation = model.invoke(
{
"segment_timecodes": "\n".join(
[f"{s.start_timestamp}-{s.end_timestamp}" for s in chunk]
),
"transcript": transcript,
}
)
clues.extend(video_annotation.segments)

return clues


def gen_annotations(llm, system_prompt: str, clues: List[SegmentWithClueInfo]):
class SegmentFeedback(BaseModel):
right: Optional[str] = Field(description="what was right in the performance")
wrong: Optional[str] = Field(description="what was wrong in the performance")
        correction: Optional[str] = Field(
            description="how and in what ways the performance could be improved"
        )

# The segment timestamps are taken from the provided information.
class SegmentCompleteAnnotation(BaseModel):
squats_probability: Optional[str] = Field(
description="how high is the probability that the person is doing squats in the segment: low, medium, high, unknown(null)"
)
squats_technique_correctness: Optional[str] = Field(
description="correctness of the squat technique."
)
        squats_feedback: Optional[SegmentFeedback] = Field(
            description="what was right and wrong in the squat performance in the segment. When the technique is incorrect, provide instructions on how to correct it."
        )

prompt_template = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
("user", "Clues: {{ clues }}"),
],
template_format="jinja2",
)

model = prompt_template | llm.with_structured_output(SegmentCompleteAnnotation)

annotations = []
for clue in clues:
segment_annotation: SegmentCompleteAnnotation = model.invoke(
{"clues": clue.json()}
)

annotations.append(segment_annotation.json())

return annotations
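
A standalone sketch (illustrative, not part of this diff) of the jinja2 prompt plus with_structured_output pattern used by both functions above; llm is assumed to be the AzureChatOpenAI instance from data_agent.py, and the timecodes and transcript strings here are invented.

# Hypothetical sketch: invoking the structured-output chain directly.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You annotate workout videos."),
        (
            "user",
            "Segment timecodes: {{ segment_timecodes }}\nTranscript: {{ transcript }}",
        ),
    ],
    template_format="jinja2",
)
chain = prompt | llm.with_structured_output(VideoAnnotation)
result = chain.invoke(
    {
        "segment_timecodes": "00:00:05.000-00:00:20.000",
        "transcript": "[00:00:05] Keep your back straight as you squat down...",
    }
)
print(result.segments[0].segment_annotation.logical_inferences)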