import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from collections import Counter
# Name of the pretrained checkpoint used for both the tokenizer and the model.
_EMBEDDING_MODEL_NAME = 'BAAI/bge-large-zh'

# Holds the loaded (device, tokenizer, model) triple so it is created only once
# per process instead of on every embedding() call.
_EMBEDDING_CACHE = {}


def _load_embedding_model():
    """Return the (device, tokenizer, model) triple, loading it on first use."""
    if 'bundle' not in _EMBEDDING_CACHE:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
        model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME).to(device)
        _EMBEDDING_CACHE['bundle'] = (device, tokenizer, model)
    return _EMBEDDING_CACHE['bundle']


def embedding(sentences):
    """Encode sentences into L2-normalized embedding vectors.

    Args:
        sentences: The text(s) handed to the tokenizer (the caller in this
            file passes a list of strings).

    Returns:
        A tensor with one row per input sentence, L2-normalized along
        dim 1, living on CUDA device 0 when available and on the CPU
        otherwise.
    """
    device, tokenizer, model = _load_embedding_model()
    encoded_input = tokenizer(
        sentences, padding=True, truncation=True, return_tensors='pt'
    ).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Use the hidden state at sequence position 0 as the sentence vector
    # (presumably the [CLS] token for this checkpoint -- confirm).
    sentence_embeddings = model_output.last_hidden_state[:, 0]

    # The result already lives on `device`, so no further transfer is needed.
    return F.normalize(sentence_embeddings, p=2, dim=1)
def cluster_sentences(sentences, cluster_num):
    """Cluster sentences with KMeans and render the grouping as HTML.

    Args:
        sentences: List of sentence strings to cluster.
        cluster_num: Requested number of clusters. It is converted to int and
            clamped to the range [1, len(sentences)], because KMeans cannot
            form more clusters than there are samples.

    Returns:
        An HTML string: one ``<h3>`` heading per cluster (with the sentence
        total and the number of sentences that occur more than once),
        followed by a ``<ul>`` listing every sentence of that cluster with
        its occurrence count. An empty ``<div></div>`` is returned when
        ``sentences`` is empty.
    """
    # Local import: the module name `html` is not imported at file level.
    from html import escape

    if not sentences:
        return "<div></div>"

    # Guard against requests for more clusters than sentences (the UI slider
    # defaults to 15) and against a float coming from the slider widget.
    n_clusters = max(1, min(int(cluster_num), len(sentences)))

    X = embedding(sentences)
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(X.cpu())

    # Group sentences by their assigned cluster label, in first-seen order.
    cluster_result = {}
    for label, sentence in zip(kmeans.labels_, sentences):
        cluster_result.setdefault(int(label), []).append(sentence)

    # Build the HTML output.
    parts = ["<div>"]
    for label, clustered_sentences in cluster_result.items():
        sentence_count = len(clustered_sentences)
        sentence_counter = Counter(clustered_sentences)
        duplicate_count = sum(
            count for count in sentence_counter.values() if count > 1
        )
        parts.append(
            f"<h3>Cluster {label} - Total Sentences: {sentence_count}, "
            f"Duplicates: {duplicate_count}</h3><ul>"
        )
        # Sentences are user input: escape them so they cannot inject markup.
        parts.extend(
            f"<li>{escape(sentence)} (Count: {sentence_counter[sentence]})</li>"
            for sentence in clustered_sentences
        )
        parts.append("</ul>")
    parts.append("</div>")
    return "".join(parts)
def main_interface(sentence_input, cluster_num):
    """Parse the comma-separated text box content and cluster the sentences.

    Args:
        sentence_input: Raw text containing sentences separated by ASCII
            commas. ``None`` is treated like an empty string.
        cluster_num: Requested number of clusters, forwarded unchanged to
            ``cluster_sentences``.

    Returns:
        The HTML string produced by ``cluster_sentences``, or an empty
        ``<div></div>`` when the input contains no non-blank sentence.
    """
    pieces = (sentence_input or "").split(',')
    # Drop blank entries produced by empty input or stray/trailing commas so
    # that empty strings are never embedded and clustered as sentences.
    sentences = [piece.strip() for piece in pieces if piece.strip()]
    if not sentences:
        return "<div></div>"
    return cluster_sentences(sentences, cluster_num)
# Assemble the Gradio user interface: a multi-line text box for the sentences
# and a slider for the KMeans cluster count; the result is rendered as HTML.
_sentence_box = gr.inputs.Textbox(
    lines=5,
    placeholder="请输入句子列表,每个句子用逗号分隔",
    label="句子输入",
)
_cluster_slider = gr.inputs.Slider(
    minimum=2,
    maximum=30,
    step=1,
    default=15,
    label="KMeans 聚类数",
)
_result_panel = gr.outputs.HTML(label="聚类结果")

iface = gr.Interface(
    fn=main_interface,
    inputs=[_sentence_box, _cluster_slider],
    outputs=_result_panel,
    title="句子嵌入和聚类",
    description="输入一组用逗号分隔的句子,选择聚类方法和相应参数,进行嵌入和聚类。",
    theme="compact",
)

# Start the Gradio app.
iface.launch()
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » gradio
发表评论 取消回复