Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 342 additions & 0 deletions week4/community-contributions/Ayesha/week4_test_case_generator.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4838a040",
"metadata": {},
"source": [
"Code Refactorer to optimize messy code, and a test case generator. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5470df06",
"metadata": {},
"outputs": [],
"source": [
"!pip install gradio transformers torch accelerate openai bitsandbytes sentencepiece python-dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c1d4e2f",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import torch\n",
"import gradio as gr\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from openai import OpenAI\n",
"from dotenv import load_dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "055a7399",
"metadata": {},
"outputs": [],
"source": [
"# =========================\n",
"# 🔹 OpenRouter Setup\n",
"# =========================\n",
"\n",
"\n",
"#connecting to openrouter\n",
"load_dotenv(override=True)\n",
"OPENROUTER_API_KEY = os.getenv('OPENROUTER_API_KEY')\n",
"\n",
"# Check the key\n",
"\n",
"if not OPENROUTER_API_KEY:\n",
" print(\"No API key was found - please head over to the troubleshooting notebook in this folder to identify & fix!\")\n",
"elif not OPENROUTER_API_KEY.startswith(\"sk-or-v1\"):\n",
"    print(\"An API key was found, but it doesn't start with sk-or-v1; please check you're using the right OpenRouter key - see troubleshooting notebook\")\n",
"elif OPENROUTER_API_KEY.strip() != OPENROUTER_API_KEY:\n",
" print(\"An API key was found, but it looks like it might have space or tab characters at the start or end - please remove them - see troubleshooting notebook\")\n",
"else:\n",
" print(\"API key found and looks good so far!\")\n",
"\n",
" \n",
"openrouter_client = OpenAI(\n",
" base_url=\"https://openrouter.ai/api/v1\",\n",
" api_key=OPENROUTER_API_KEY\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da6c871a",
"metadata": {},
"outputs": [],
"source": [
"# =========================\n",
"# 🔹 Hugging Face Setup\n",
"# =========================\n",
"HG_TOKEN = os.getenv(\"HG_TOKEN\")\n",
"\n",
"if HG_TOKEN:\n",
" if HG_TOKEN.startswith(\"hf_\") and HG_TOKEN.strip() == HG_TOKEN:\n",
" print(\"Hugging Face token was found.\")\n",
" else:\n",
" print(\"Hugging Face token format is not correct.\")\n",
"else:\n",
" print(\"No Hugging Face token found. Some HF models may fail.\")\n",
"\n",
"\n",
"HF_MODELS = {\n",
" \"Phi-3 Mini\": \"microsoft/Phi-3-mini-4k-instruct\",\n",
" \"TinyLlama\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d06a8051",
"metadata": {},
"outputs": [],
"source": [
"loaded_models = {}\n",
"\n",
"def load_hf_model(model_key):\n",
" if model_key in loaded_models:\n",
" return loaded_models[model_key]\n",
"\n",
" model_name = HF_MODELS[model_key]\n",
" print(f\"Loading Hugging Face model: {model_name}\")\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(\n",
" model_name,\n",
" token=HG_TOKEN\n",
" )\n",
"\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" token=HG_TOKEN,\n",
" torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n",
" device_map=\"auto\"\n",
" )\n",
"\n",
" loaded_models[model_key] = (tokenizer, model)\n",
" return tokenizer, model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77378945",
"metadata": {},
"outputs": [],
"source": [
"# Build the refactor prompt\n",
"def build_refactor_prompt(code):\n",
" return f\"\"\"\n",
"You are a senior software engineer.\n",
"\n",
"Refactor the following code to:\n",
"- Improve readability\n",
"- Add type hints (if Python)\n",
"- Optimize logic if possible\n",
"- Follow best practices\n",
"- Keep functionality identical\n",
"\n",
"Return ONLY valid JSON in this format:\n",
"\n",
"{{\n",
" \"refactored_code\": \"...\",\n",
" \"improvements\": [\"improvement 1\", \"improvement 2\"]\n",
"}}\n",
"\n",
"Code:\n",
"{code}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "956d2a39",
"metadata": {},
"outputs": [],
"source": [
"#building the test case generator prompt\n",
"def build_test_prompt(code):\n",
" return f\"\"\"\n",
"You are a senior software engineer.\n",
"\n",
"Generate pytest unit tests for the following code.\n",
"Include:\n",
"- Normal cases\n",
"- Edge cases\n",
"- Failure cases if applicable\n",
"\n",
"Return ONLY valid JSON in this format:\n",
"\n",
"{{\n",
" \"unit_tests\": \"pytest code here\"\n",
"}}\n",
"\n",
"Code:\n",
"{code}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "815c2274",
"metadata": {},
"outputs": [],
"source": [
"def generate_with_openrouter(prompt, model_name):\n",
"    if not OPENROUTER_API_KEY:\n",
" raise ValueError(\"OpenRouter API key not configured.\")\n",
"\n",
" response = openrouter_client.chat.completions.create(\n",
" model=model_name,\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" temperature=0.5,\n",
" )\n",
" return response.choices[0].message.content.strip()\n",
"\n",
"def generate_with_huggingface(prompt, model_key):\n",
" tokenizer, model = load_hf_model(model_key)\n",
"\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
" inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
"\n",
" output = model.generate(\n",
" **inputs,\n",
" max_new_tokens=1500,\n",
" temperature=0.5,\n",
" do_sample=True\n",
" )\n",
"\n",
" generated_text = tokenizer.decode(\n",
" output[0][inputs[\"input_ids\"].shape[1]:],\n",
" skip_special_tokens=True\n",
" )\n",
"\n",
" return generated_text.strip()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0cb05fda",
"metadata": {},
"outputs": [],
"source": [
"def validate_json(text):\n",
" try:\n",
" return json.loads(text), None\n",
" except Exception as e:\n",
" return None, str(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5f983c7",
"metadata": {},
"outputs": [],
"source": [
"def run_assistant(task_type, code, model_source, model_choice):\n",
"\n",
" if task_type == \"Refactor\":\n",
" prompt = build_refactor_prompt(code)\n",
" else:\n",
" prompt = build_test_prompt(code)\n",
"\n",
" try:\n",
" if model_source == \"OpenRouter\":\n",
" raw_output = generate_with_openrouter(prompt, model_choice)\n",
" else:\n",
" raw_output = generate_with_huggingface(prompt, model_choice)\n",
"\n",
" data, error = validate_json(raw_output)\n",
"\n",
" if error:\n",
" return f\"JSON Error: {error}\", raw_output\n",
"\n",
" return \"Success\", json.dumps(data, indent=2)\n",
"\n",
" except Exception as e:\n",
" return f\"Error: {str(e)}\", None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "648b584f",
"metadata": {},
"outputs": [],
"source": [
"with gr.Blocks() as demo:\n",
" gr.Markdown(\"## Code Refactoring & Test Generator Assistant\")\n",
"\n",
" task_type = gr.Radio(\n",
" [\"Refactor\", \"Generate Tests\"],\n",
" value=\"Refactor\",\n",
" label=\"Task Type\"\n",
" )\n",
"\n",
" code_input = gr.Textbox(\n",
" label=\"Paste Code Here\",\n",
" lines=15\n",
" )\n",
"\n",
" model_source = gr.Radio(\n",
" [\"OpenRouter\", \"Hugging Face\"],\n",
" value=\"Hugging Face\",\n",
" label=\"Model Source\"\n",
" )\n",
"\n",
" model_choice = gr.Dropdown(\n",
" choices=[\"Phi-3 Mini\", \"TinyLlama\", \"openai/gpt-4o-mini\"],\n",
" value=\"Phi-3 Mini\",\n",
" label=\"Model\"\n",
" )\n",
"\n",
" status = gr.Textbox(label=\"Status\")\n",
" output_box = gr.Textbox(label=\"Output\", lines=20)\n",
"\n",
" run_btn = gr.Button(\"Run\")\n",
"\n",
" run_btn.click(\n",
" run_assistant,\n",
" inputs=[task_type, code_input, model_source, model_choice],\n",
" outputs=[status, output_box]\n",
" )\n",
"\n",
"demo.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading