from collections import Counter
from functools import lru_cache
from html import escape

import gradio as gr
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from transformers import AutoTokenizer, AutoModel

@lru_cache(maxsize=2)
def _load_encoder(model_name):
    """Load the tokenizer/model pair for *model_name* once and memoise it.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model has already been
        moved to ``device`` (``cuda:0`` when CUDA is available, else CPU).
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    # Inference only: make sure the module is in evaluation mode.
    model.eval()
    return tokenizer, model, device


def embedding(sentences, model_name='BAAI/bge-large-zh'):
    """Encode sentences into L2-normalised embedding vectors.

    The embedding of each sentence is the last-layer hidden state at sequence
    position 0, normalised to unit L2 norm along the feature dimension.

    Args:
        sentences: A list of strings to encode.
        model_name: Identifier passed to ``from_pretrained``. Defaults to the
            model this script originally hard-coded.

    Returns:
        A 2-D ``torch.Tensor`` with one row per sentence, located on the
        device the model runs on.
    """
    # Loading the weights is by far the most expensive step, so it is cached
    # across calls instead of being repeated for every request.
    tokenizer, model, device = _load_encoder(model_name)

    encoded_input = tokenizer(
        sentences, padding=True, truncation=True, return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        model_output = model(**encoded_input)
        first_token_state = model_output.last_hidden_state[:, 0]
        # The result already lives on ``device``; no further transfer needed.
        return F.normalize(first_token_state, p=2, dim=1)

def cluster_sentences(sentences, cluster_num):
    """Cluster sentences with KMeans and render the result as HTML.

    Args:
        sentences: Sequence of sentence strings to cluster.
        cluster_num: Requested number of clusters. Coerced to ``int`` and
            capped at the number of sentences.

    Returns:
        An HTML string: one ``<h3>`` heading plus ``<ul>`` list per cluster.
        Each heading shows the cluster size and the number of entries whose
        text occurs more than once in that cluster; each list item shows the
        sentence and its occurrence count within the cluster. Returns
        ``"<div></div>"`` when there are no sentences.
    """
    sentences = list(sentences)
    if not sentences:
        return "<div></div>"

    # UI sliders may deliver floats, and KMeans raises when asked for more
    # clusters than there are samples, so coerce and cap the request.
    n_clusters = min(int(cluster_num), len(sentences))

    X = embedding(sentences)
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(X.detach().cpu().numpy())

    # Group sentences by their assigned cluster label, keeping input order.
    cluster_result = {}
    for label, sentence in zip(kmeans.labels_, sentences):
        cluster_result.setdefault(int(label), []).append(sentence)

    # Build the HTML output
    parts = ["<div>"]
    for label, clustered_sentences in cluster_result.items():
        sentence_count = len(clustered_sentences)
        sentence_counter = Counter(clustered_sentences)
        duplicate_count = sum(
            count for count in sentence_counter.values() if count > 1
        )
        parts.append(
            f"<h3>Cluster {label} - Total Sentences: {sentence_count}, "
            f"Duplicates: {duplicate_count}</h3><ul>"
        )
        for sentence in clustered_sentences:
            count = sentence_counter[sentence]
            # Sentences are user-supplied text: escape them so they cannot
            # inject markup into the rendered page.
            parts.append(f"<li>{escape(sentence)} (Count: {count})</li>")
        parts.append("</ul>")
    parts.append("</div>")

    return "".join(parts)

def main_interface(sentence_input, cluster_num):
    """Parse the comma-separated textbox input and cluster the sentences.

    Args:
        sentence_input: Raw text with sentences separated by ASCII commas.
            ``None`` is treated as empty input.
        cluster_num: Requested number of clusters, forwarded unchanged.

    Returns:
        The HTML string produced by ``cluster_sentences``, or
        ``"<div></div>"`` when the input contains no non-blank sentence.
    """
    fragments = (piece.strip() for piece in (sentence_input or "").split(','))
    # Drop blank fragments (trailing commas, doubled commas, empty input) so
    # they are not embedded and clustered as if they were sentences.
    sentences = [piece for piece in fragments if piece]
    if not sentences:
        return "<div></div>"
    return cluster_sentences(sentences, cluster_num)

# Build the user interface with Gradio.
# NOTE(review): gr.inputs / gr.outputs look like the legacy Gradio component
# namespaces — confirm the installed Gradio version still provides them.
sentence_box = gr.inputs.Textbox(
    lines=5,
    placeholder="请输入句子列表,每个句子用逗号分隔",
    label="句子输入",
)
cluster_slider = gr.inputs.Slider(
    minimum=2,
    maximum=30,
    step=1,
    default=15,
    label="KMeans 聚类数",
)
result_panel = gr.outputs.HTML(label="聚类结果")

iface = gr.Interface(
    fn=main_interface,
    inputs=[sentence_box, cluster_slider],
    outputs=result_panel,
    title="句子嵌入和聚类",
    description="输入一组用逗号分隔的句子,选择聚类方法和相应参数,进行嵌入和聚类。",
    theme="compact",
)

# Launch the Gradio app.
iface.launch()

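# NOTE: everything below is web-page residue captured together with the
# script (like/tip buttons, comment list, WeChat follow prompt, "back to
# top" link). It is not Python and is commented out so the file parses.
#
# 点赞(0) 打赏
#
# 评论列表 共有 0 条评论
#
# 暂无评论
#
# 微信公众账号
#
# 微信扫一扫加关注
#
# 发表
# 评论
# 返回
# 顶部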