// Handles the heatmap event, group similarity logic, and text reordering for cluster visualization
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
import { plotHeatmap } from './plotting.js';

// Instruction string passed to both embedding helpers together with the texts to embed.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";

/**
 * Cosine similarity between two numeric vectors of equal length.
 * @param {number[]} a
 * @param {number[]} b
 * @returns {number} dot(a, b) / (|a| * |b|), or 0 when either vector has zero magnitude
 *   (avoids a NaN from dividing by zero).
 */
function cosine(a, b) {
  let dot = 0;
  let na = 0;
  let nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return na && nb ? dot / Math.sqrt(na * nb) : 0;
}

/**
 * Single rule for recognising a group header, shared by name extraction and content stripping
 * so the two can never disagree about leading whitespace.
 * @param {string} line
 * @returns {boolean} true when the trimmed line starts with "##".
 */
function isHeaderLine(line) {
  return line.trim().startsWith("##");
}

/**
 * Split the main input into groups separated by three or more consecutive newlines.
 * Whitespace-only chunks (from leading/trailing or extra blank lines) are dropped so they
 * are neither embedded nor counted as groups.
 * @param {string} text
 * @returns {string[]} Raw group texts, headers included.
 */
function splitGroups(text) {
  return text.split(/\n{3,}/).filter(g => g.trim() !== "");
}

/**
 * Name of each group, index-aligned with `groups`: the text of the group's first header line
 * with the leading "##" and following whitespace removed, or null if the group has no header.
 * @param {string[]} groups
 * @returns {(string|null)[]}
 */
function getClusterNames(groups) {
  return groups.map(g => {
    const header = g.split("\n").find(isHeaderLine);
    return header === undefined ? null : header.trim().replace(/^##\s*/, "");
  });
}

/**
 * Display label for group `i`: its cluster name, or "Group <i+1>" (1-based original index)
 * when it has no non-empty name. Used for both heatmap labels and rewritten headers so one
 * group always carries one name.
 * @param {(string|null)[]} clusterNames
 * @param {number} i Original group index.
 * @returns {string}
 */
function labelFor(clusterNames, i) {
  return clusterNames?.[i] || `Group ${i + 1}`;
}

/**
 * Content lines of each group: lines that are neither blank/whitespace-only nor headers.
 * Kept lines are returned verbatim (not trimmed).
 * @param {string[]} groups
 * @returns {string[][]}
 */
function getCleanGroups(groups) {
  return groups.map(g => g.split("\n").filter(l => l.trim() !== "" && !isHeaderLine(l)));
}

/**
 * Flatten all lines from all groups and embed them in a single call.
 * @param {string[][]} cleanGroups
 * @param {string} task Instruction passed through to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: number[][]}>}
 *   Assumes getLineEmbeddings returns one vector per line, in order — verify in embedding.js.
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
  const allLines = cleanGroups.flat();
  const allEmbeds = await getLineEmbeddings(allLines, task);
  return { allLines, allEmbeds };
}

/**
 * For each group, the indices its lines occupy in the flattened line array produced by
 * `cleanGroups.flat()`.
 * @param {string[][]} cleanGroups
 * @returns {number[][]}
 */
function getIdxByGroup(cleanGroups) {
  const idxByGroup = [];
  let p = 0;
  for (const g of cleanGroups) {
    idxByGroup.push(Array.from({ length: g.length }, (_, i) => p + i));
    p += g.length;
  }
  return idxByGroup;
}

/**
 * Element-wise mean of a non-empty list of equal-length vectors.
 * @param {number[][]} vectors
 * @returns {number[]}
 */
function meanVector(vectors) {
  return vectors[0].map((_, k) => vectors.reduce((sum, v) => sum + v[k], 0) / vectors.length);
}

/**
 * Full pairwise cosine-similarity matrix (zero-magnitude vectors yield 0, not NaN).
 * @param {number[][]} embeddings
 * @returns {number[][]} m[i][j] = cosine(embeddings[i], embeddings[j]).
 */
function similarityMatrix(embeddings) {
  return embeddings.map(a => embeddings.map(b => cosine(a, b)));
}

/**
 * Build the rewritten main-input text: one "## <label>" header per group followed by its
 * sorted lines, groups separated by two blank lines (three newlines).
 * @param {number[]} order Original group indices in output order.
 * @param {string[][]} sortedLines sortedLines[i] holds the lines of group order[i].
 * @param {(string|null)[]} clusterNames Per-group names, index-aligned with the original groups.
 * @returns {string}
 */
function buildFinalText(order, sortedLines, clusterNames) {
  return order
    .map((gIdx, i) => `## ${labelFor(clusterNames, gIdx)}\n${sortedLines[i].join("\n")}`)
    .join("\n\n\n");
}

/**
 * Scoring function used to rank individual lines against the search group.
 * @param {string} mode "line": best cosine against any single search line;
 *   "group": cosine against the mean search embedding; anything else: no reordering.
 * @param {number[][]} searchEmbeds
 * @param {number[]} refEmbed Mean of searchEmbeds.
 * @returns {((e: number[]) => number) | null}
 */
function getLineScorer(mode, searchEmbeds, refEmbed) {
  if (mode === "line") return e => Math.max(...searchEmbeds.map(se => cosine(se, e)));
  if (mode === "group") return e => cosine(refEmbed, e);
  return null;
}

/**
 * Embed every content line and return each group's lines sorted by descending score,
 * with groups emitted in `order`.
 * @param {string[]} groups Raw group texts.
 * @param {number[]} order Original group indices in output order.
 * @param {(e: number[]) => number} scoreLine
 * @returns {Promise<string[][]>}
 */
async function sortLinesWithinGroups(groups, order, scoreLine) {
  const cleanGroups = getCleanGroups(groups);
  const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
  const idxByGroup = getIdxByGroup(cleanGroups);
  return order.map(g =>
    idxByGroup[g]
      .map(i => ({ t: allLines[i], s: scoreLine(allEmbeds[i]) }))
      .sort((a, b) => b.s - a.s)
      .map(o => o.t)
  );
}

/**
 * Compute group-to-group similarities for the text in #input and plot them as a heatmap.
 *
 * If #search-group-input contains non-blank lines, groups are ordered by descending similarity
 * to the mean search-line embedding, a "Search" row/column is prepended to the heatmap, and —
 * when #search-sort-mode is "line" or "group" (default "group") — #input is rewritten with the
 * groups in that order and each group's lines sorted by similarity to the search lines.
 *
 * Reads #progress-bar, #progress-bar-inner, #input, #search-group-input, #search-sort-mode.
 * Assumes getGroupEmbeddings returns one vector per group in input order and ignores "##"
 * header lines itself — verify in embedding.js.
 * @returns {Promise<void>}
 */
export async function handleHeatmapEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  const input = document.getElementById("input");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";

  // The search group comes only from its dedicated input (never from a ##search header in #input).
  const searchGroupText = document.getElementById("search-group-input")?.value.trim();
  const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";

  const groups = splitGroups(input.value);
  const clusterNames = getClusterNames(groups);

  const groupEmbeddings = await getGroupEmbeddings(groups, task);
  progressBarInner.style.width = "30%";

  const sim = similarityMatrix(groupEmbeddings);
  progressBarInner.style.width = "60%";

  const searchLines = searchGroupText
    ? searchGroupText.split("\n").map(l => l.trim()).filter(l => l)
    : [];
  const searchEmbeds = searchLines.length > 0 ? await getLineEmbeddings(searchLines, task) : [];

  if (searchEmbeds.length === 0) {
    // No search group: plot the groups in input order.
    const labels = groupEmbeddings.map((_, i) => labelFor(clusterNames, i));
    plotHeatmap(sim, labels, labels);
    progressBarInner.style.width = "100%";
    return;
  }

  const refEmbed = meanVector(searchEmbeds);
  const simToSearch = groupEmbeddings.map(e => cosine(refEmbed, e));
  // Most similar group first; Array.prototype.sort is stable, so ties keep input order.
  const order = groupEmbeddings.map((_, i) => i).sort((a, b) => simToSearch[b] - simToSearch[a]);

  const scoreLine = getLineScorer(searchSortMode, searchEmbeds, refEmbed);
  if (scoreLine) {
    const sortedLines = await sortLinesWithinGroups(groups, order, scoreLine);
    input.value = buildFinalText(order, sortedLines, clusterNames);
  }

  // Row/column 0 is the search group: similarity 1 with itself, cosine to each group elsewhere.
  const simOrdered = [
    [1, ...order.map(i => simToSearch[i])],
    ...order.map(i => [simToSearch[i], ...order.map(j => sim[i][j])]),
  ];
  const labels = ["Search", ...order.map(i => labelFor(clusterNames, i))];
  plotHeatmap(simOrdered, labels, labels);
  progressBarInner.style.width = "100%";
}