
Commit

[Chore] update ai model (#40)
jyoo0515 authored Feb 21, 2024
1 parent b8ee5ba commit 3189890
Showing 5 changed files with 149 additions and 91 deletions.
18 changes: 9 additions & 9 deletions ai/example.py
@@ -14,7 +14,7 @@

pro_input = {
"user-audio": "example1.m4a",
"practice-sentence": "It's autumn now, and the leaves are turning beautiful colors.",
"practice-sentece": "It's autumn now, and the leaves are turning beautiful colors.",
"tip": "Say 'aw-tum,' not 'ay-tum.'"
}

@@ -23,16 +23,16 @@
pro_feedback = ft.ProFeedback(pro_input)

## Sample Output ##
print("** Feedback for Contextual Task: ", com_feedback)
# {'positive-feedback': 'You are very creative! I like your imagination.',
print("** Feedback for Communication Task: ", com_feedback)
# {'positive-feedback': 'You are very creative! I like your imagination.',
# 'negative-feedback': "Let's try to describe what we see in the picture. First, look at the sky. What colors can you see there?",
# 'enhanced-answer': 'In the sky, I can see yellow, orange, pink, and blue.'}

print("** Feedback for Pronunciation Task: ", pro_feedback)
- # {'transcription': 'ITS AUTUMN NOW AND THE LEAVES ARE TURNING BEAUTIFUL COLORS',
- # 'wrong_idx': {'minor': [2, 9], 'major': []},
- # 'pronunciation_score': 0.7,
- # 'decibel': 46.90759735625882,
- # 'speech_rate': 2.347417840375587,
- # 'positive-feedback': 'Pronunciation is correct. Keep up the good work!',
+ # {'transcription': 'ITS AUTUMN NOW AND THE LEAVES ARE TURNING BEAUTIFUL COLORS',
+ # 'wrong_idx': {'minor': [2, 9], 'major': []},
+ # 'pronunciation_score': 0.7,
+ # 'decibel': 46.90759735625882,
+ # 'speech_rate': 2.347417840375587,
+ # 'positive-feedback': 'Pronunciation is correct. Keep up the good work!',
# 'negative-feedback': ' '}
105 changes: 45 additions & 60 deletions ai/fluentify.py
@@ -4,15 +4,13 @@
import librosa
import numpy as np
import stable_whisper
- import torch
import vertexai
import yaml
from evaluate import load
- from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
from vertexai.preview.generative_models import GenerativeModel, Part

from utils.utils import text2dict
- from utils.word_process import get_incorrect_idx
+ from utils.word_process import get_incorrect_idxes


# gcloud auth application-default login
@@ -25,21 +23,19 @@ def __init__(self):
self.lang_model = GenerativeModel("gemini-pro")
self.current_path = os.path.dirname(__file__)
self.gcs_path = "gs://fluentify-412312.appspot.com"
- self.shared_path = "./shared-data"
- self.ars_whisper_model = stable_whisper.load_model('base')
+ self.ars_model = stable_whisper.load_model('base')

with open(os.path.join(self.current_path, 'data/prompt.yaml'), 'r', encoding='UTF-8') as file:
self.prompt = yaml.load(file, Loader=yaml.FullLoader)

- self.audio_path = self.shared_path + "/audio"
- self.tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
- self.ars_w2v_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
- self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h",
-                                                                sampling_rate=16000)
+ self.audio_path = "./data/audio"
+ # self.tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+ # self.ars_w2v_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+ # self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h", sampling_rate=16000)
self.wer = load("wer")

- self.speech_rate_threshold_h = 2.5
- self.speech_rate_threshold_l = 1.0
+ self.speech_rate_threshold_h = 4.5
+ self.speech_rate_threshold_l = 2.5
self.decibel_threshold_h = 95
self.decibel_threshold_l = 45

@@ -49,37 +45,26 @@ def __init__(self):
def GetWer(self, transcription, ground_truth):
# WER score is from 0 to 1
wer_score = self.wer.compute(predictions=[transcription], references=[ground_truth])
- incorrect_idx = get_incorrect_idx(ground_truth, transcription)
- incorrect_idx = {"minor": incorrect_idx["substitutions"], "major": incorrect_idx["deletions"]}
- return wer_score, incorrect_idx

- def ASR(self, audio_path):
- ## TODO : Audio File Path ####
- input_audio = librosa.load(audio_path, sr=22000)[0]
- input_values = self.feature_extractor(input_audio, return_tensors="pt", padding="longest").input_values
- rms = librosa.feature.rms(y=input_audio)
+ incorrect_idxes = get_incorrect_idxes(ground_truth, transcription)
+ self.wer_score = wer_score
+ self.incorrect_idxes = incorrect_idxes["substitutions"] + incorrect_idxes["deletions"]

- with torch.no_grad():
- logits = self.ars_w2v_model(input_values).logits[0]
- pred_ids = torch.argmax(logits, axis=-1)
- outputs = self.tokenizer.decode(pred_ids, output_word_offsets=True)
- time_offset = self.ars_w2v_model.config.inputs_to_logits_ratio / self.feature_extractor.sampling_rate

- transcription = outputs[0]
- word_offsets = [
- {
- "word": d["word"],
- "start_time": round(d["start_offset"] * time_offset, 2),
- "end_time": round(d["end_offset"] * time_offset, 2),
- "time": round(d["end_offset"] * time_offset - d["start_offset"] * time_offset, 2),
- }
- for d in outputs.word_offsets
- ]

- time = word_offsets[-1]["end_time"]
- speech_rate = len(word_offsets) / word_offsets[-1]["end_time"]
+ def GetDecibel(self, audio_path):
+ input_audio, _ = librosa.load(audio_path, sr=22000)
+ rms = librosa.feature.rms(y=input_audio)
decibel = 20 * math.log10(np.mean(rms) / 0.00002)
- return transcription, time, speech_rate, decibel
+ return decibel

+ def EvaluatePro(self, audio_path):
+ transcription = self.ars_model.transcribe(audio_path)
+ self.pro_transcription = transcription.text

+ word_offset = transcription.segments
+ start, end = word_offset[0].start, word_offset[-1].end
+ duration = end - start

+ self.speech_rate = len(self.pro_transcription.split(' ')) / duration
+ self.decibel = self.GetDecibel(audio_path)

def Score2Des(self, output):
feedback = {
@@ -113,11 +98,12 @@ def ScoreMap(self, input, threshold):
return 5

def ComFeedback(self, input):
image = Part.from_uri(f"{self.gcs_path}/{input['img']}", mime_type="image/jpeg")
response = self.ars_whisper_model.transcribe(f"{self.audio_path}/{input['user-audio']}")
image = Part.from_uri(f"{self.gcs_path}/img/{input['img']}", mime_type="image/jpeg")
response = self.ars_model.transcribe(f"{self.audio_path}/{input['user-audio']}")
self.com_trnascription = response.text
output = {"transcription": self.com_trnascription}

output = {"transcription": response.text}
#### TODO : Feedback focusing more on sentence formulation #####
#### TODO : Feedback focusing more on sentence formulation #####
prompt = f"{self.prompt['con-feedback']}".format(context=input["context"], question=input["question"],
answer=input["answer"], user_answer=output['transcription'])
response = self.multimodal_model.generate_content([prompt, image])
@@ -131,30 +117,29 @@ def ComFeedback(self, input):
return None

def ProFeedback(self, input):
ground_truth = input["practice-sentence"]
transcription, time, speech_rate, decibel = self.ASR(f"{self.audio_path}/{input['user-audio']}")
wer_score, incorrect_idx = self.GetWer(transcription.upper(), ground_truth.upper())

ground_truth = input["practice-sentece"]
self.EvaluatePro(f"{self.audio_path}/{input['user-audio']}")
self.GetWer(self.pro_transcription.upper(), ground_truth.upper())
sentence_lst = ground_truth.split(" ")

- pronunciation_score = self.ScoreMap(1 - wer_score, 1)
- speed_socre = self.ScoreMap(speech_rate, self.speech_rate_threshold)
- volume_score = self.ScoreMap(decibel, self.decibel_threshold)
+ pronunciation_score = self.ScoreMap(1 - self.wer_score, 1)
+ speed_score = self.ScoreMap(self.speech_rate, self.speech_rate_threshold)
+ volume_score = self.ScoreMap(self.decibel, self.decibel_threshold)

output = {
"transcription": transcription,
"incorrect_indexes": incorrect_idx['major'],
"transcription": self.pro_transcription,
"incorrect_indexes": self.incorrect_idxes,
"decibel": self.decibel,
"speech_rate": self.speech_rate,
"pronunciation_score": pronunciation_score, # higher the better
"voulume_score": volume_score, # higher the better
"speed_socre": speed_socre,
"decibel": decibel,
"speech_rate": speech_rate
"volume_score": volume_score, # higher the better
"speed_score": speed_score
}

feedback = self.Score2Des(output)

- ## TODO : Decide whether to use only major errors or only minor errors ##
- # only if there is major incorrect pronunciation
+ ## TODO : Decide whether to use only major errors or only minor errors ##
+ # only if there is major incorrect pronunciation
if len(output["incorrect_indexes"]) > 0:
for idx in output["incorrect_indexes"]:
feedback['incorrect_pronunciation'] = f"Your pronunciation for {sentence_lst[idx]} is not correct"
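For context on the fluentify.py change: ProFeedback now takes its transcription and segment timings from stable_whisper and derives loudness from a librosa RMS estimate, instead of the removed wav2vec2 CTC path. Below is a minimal standalone sketch of that measurement step using the same libraries; the helper name and audio file are illustrative, not part of this commit.

import math

import librosa
import numpy as np
import stable_whisper


def measure_audio(audio_path):
    # Transcribe and read segment timings, as EvaluatePro does.
    model = stable_whisper.load_model('base')
    result = model.transcribe(audio_path)
    duration = result.segments[-1].end - result.segments[0].start

    # Words per second over the spoken span.
    speech_rate = len(result.text.split(' ')) / duration

    # Loudness from mean RMS against a 2e-5 reference, as GetDecibel does.
    y, _ = librosa.load(audio_path, sr=22000)
    rms = librosa.feature.rms(y=y)
    decibel = 20 * math.log10(np.mean(rms) / 0.00002)

    return result.text, speech_rate, decibel


print(measure_audio("example1.m4a"))  # sample clip referenced in ai/example.py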
56 changes: 43 additions & 13 deletions ai/talking-video-gen.py
@@ -1,15 +1,14 @@
# Imports the Google Cloud stt library
import json
import os
- import time
from datetime import datetime

import cv2
import pandas as pd
import stable_whisper
from PIL import Image
from google.cloud import texttospeech
from moviepy.editor import VideoFileClip, AudioFileClip
+ from tqdm import tqdm


# gcloud auth application-default login
@@ -26,16 +25,17 @@ def __init__(self):
self.data_path = './data'
self.character_path = f'{self.data_path}/character'

- self.ts = datetime.now().strftime("%m%d%H%M%S")
- self.dir_path = f'{self.data_path}/video_output/{self.ts}'
+ # self.ts = datetime.now().strftime("%m%d%H%M%S")
+ # self.dir_path = f'{self.data_path}/video_output/{self.ts}'
+ self.dir_path = f'{self.data_path}/pro-video-output'

self.audio_path = self.dir_path + '/output.mp3'
self.video_path = self.dir_path + "/output.mp4"
self.wordoffset_path = self.dir_path + '/word_offset.json'
self.output_path = self.dir_path + '/fianal-output.mp4'

- if not os.path.exists(self.dir_path):
-     os.makedirs(self.dir_path)
+ # if not os.path.exists(self.dir_path):
+ #     os.makedirs(self.dir_path)

img_lst = os.listdir(self.character_path)
self.close_img_lst = [img for img in img_lst if 'close' in img]
@@ -96,7 +96,7 @@ def VideoGen(self):
open_img = self.close_img_lst[0]

for word in self.wordoffset.iloc:
- ## CLOSE ##
+ ## CLOSE ##
if word['content'] == None:
# close_img = random.choice(self.close_img_lst)
for _ in range(int(word['time'])):
@@ -122,12 +122,14 @@ def VideoGen(self):
video.release()

def MergeAudioVideo(self):

video_clip = VideoFileClip(self.video_path)
audio_clip = AudioFileClip(self.audio_path)
final_clip = video_clip.set_audio(audio_clip)
final_clip.write_videofile(self.output_path)

- def Text2TalkingVideo(self, sentence):
+ def Text2TalkingVideo(self, sentence, fname="feedback"):
+     self.output_path = f'{self.dir_path}/{fname}.mp4'
self.Text2Speech(sentence)
self.Speech2Text()
self.WordOffset()
@@ -138,9 +140,37 @@ def Text2TalkingVideo(self, sentence):

if __name__ == "__main__":
tg = TalkingGenerator()
sentence = "The most important thing to keep in mind when designing this car would be to make sure that it is made of a metal that the superhero can control with his mind."

with open('./data/pro-data.json') as f:
pro_dataset = json.load(f)

with open('./data/com-data.json') as f:
com_dataset = json.load(f)

for i, pro_data in tqdm(enumerate(pro_dataset)):
pro_data.update({'practice-sentence-w-tip': pro_data['practice-sentence'] + ' ' + pro_data['tip']})
sentence = pro_data["practice-sentence-w-tip"]
fname = pro_data["id"]
tg.Text2TalkingVideo(sentence, fname)

# for i,com_data in tqdm(enumerate(com_dataset)):
# sentence = com_data["question"]
# fname = com_data["id"]
# tg.Text2TalkingVideo(sentence, fname)

with open('./data/pro-data.json', 'w') as f:
json.dump(pro_dataset, f, indent=4)

# with open('./data/com-data.json', 'w') as f:
# json.dump(com_dataset, f, indent=4)
# for com_data in tqdm(com_dataset):
# sentence = com_data['question']
# fname = com_data['img'].split('.')[0]
# tg.Text2TalkingVideo(sentence, fname)

# sentence ="The most important thing to keep in mind when designing this car would be to make sure that it is made of a metal that the superhero can control with his mind."
# sentence = "Let's imagine that you are a brave captain of a big ship. You are sailing on the high seas. Suddenly, you see a beautiful sunset. Look at this picture and tell me..."
- time_ = time.time()
- tg.Text2TalkingVideo(sentence)
- time_ = time.time() - time_
- print(time_)
+ # time_ = time.time()
+ # tg.Text2TalkingVideo(sentence=,fname=)
+ # time_ = time.time() - time_
+ # print(time_)
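The batch loop above drives TalkingGenerator.Text2Speech, whose body is not part of this diff. As a rough sketch only, a minimal version built on the google.cloud.texttospeech client imported above could look like the following; the function name, voice, and encoding choices are assumptions, not the project's actual implementation.

from google.cloud import texttospeech


def text_to_mp3(sentence, audio_path):
    # Synthesize the sentence and write an MP3, roughly what Text2Speech
    # needs to produce the './data/pro-video-output/output.mp3' the class expects.
    client = texttospeech.TextToSpeechClient()
    response = client.synthesize_speech(
        input=texttospeech.SynthesisInput(text=sentence),
        voice=texttospeech.VoiceSelectionParams(language_code="en-US"),
        audio_config=texttospeech.AudioConfig(
            audio_encoding=texttospeech.AudioEncoding.MP3
        ),
    )
    with open(audio_path, "wb") as f:
        f.write(response.audio_content)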
48 changes: 46 additions & 2 deletions ai/utils/utils.py
@@ -1,5 +1,5 @@
- import json
import ast
+ import json


def text2dict(text):
@@ -10,4 +10,48 @@ def text2dict(text):
try:
return ast.literal_eval(text)
except:
- return json.loads(text) if text else None
+ try:
+     return json.loads(text)
+ except:
+     text = text.split('positive-feedback')[-1]
+     if 'enhanced-answer' in text:
+         pos, neg_w_enh = text.split('negative-feedback')
+         neg, enh = neg_w_enh.split('enhanced-answer')
+         pos, neg, enh = clean(pos), clean(neg), clean(enh)
+         return {"positive-feedback": pos, "negative-feedback": neg, "enhanced-answer": enh}
+     else:
+         pos, neg = text.split('negative-feedback')
+         pos, neg = clean(pos), clean(neg)
+         return {"positive-feedback": pos, "negative-feedback": neg}


+ def clean(text):
+     text = text.strip()
+     if text.startswith('"') or text.startswith("'"):
+         text = text[1:]
+     if text.endswith('"') or text.endswith("'"):
+         text = text[:-1]

+     if ":" in text:
+         text = text.replace(":", "")
+     if "," in text:
+         text = text.replace(",", "")
+     if "{" in text:
+         text = text.replace("{", "")
+     if "}" in text:
+         text = text.replace("}", "")

+     if "\'s" in text:
+         text = text.replace("\'s", " 's")
+     if "\'m" in text:
+         text = text.replace("\'s", " 'm")

+     text = text.strip()

+     if text.startswith('"') or text.startswith("'"):
+         text = text[1:]
+     if text.endswith('"') or text.endswith("'"):
+         text = text[:-1]

+     text = text.strip()
+     return text
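To make the new fallback concrete: when the model returns text that is neither a Python literal nor valid JSON, text2dict now splits on the field names and scrubs each fragment with clean. A small illustration, assuming the script runs from the ai/ directory; the reply string is made up:

from utils.utils import text2dict

# Neither ast.literal_eval nor json.loads can parse this reply,
# so the new split-and-clean branch handles it.
raw = ("Sure! 'positive-feedback': 'Great colors in your answer', "
       "'negative-feedback': 'Try adding where you see them'")

print(text2dict(raw))
# {'positive-feedback': 'Great colors in your answer',
#  'negative-feedback': 'Try adding where you see them'}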
