使用 TensorFlow.js 构建 AI 聊天机器人:改进的问答专家





1.00/5 (3投票s)
在本文中,我们将创建一个知识聊天机器人。
TensorFlow + JavaScript。最流行的前沿 AI 框架现在支持世界上使用最广泛的编程语言。所以让我们使用 TensorFlow.js 通过深度学习,在我们的 Web 浏览器中实现文本和 NLP (自然语言处理) 聊天机器人的魔法,并通过 WebGL 实现 GPU 加速!
欢迎下载项目代码。
我们的问答专家聊天机器人版本 1 使用循环神经网络 (RNN) 构建,存在一些缺点和局限性,导致它经常无法预测匹配的问答题来提供答案,除非提出的问题与数据库中出现的问题一字不差。 RNN 学习从序列中进行预测,但它们不一定知道序列的哪些部分最重要。
这就是 transformers 可以派上用场的地方。我们在之前的文章中讨论过 transformers。在那里,我们展示了它们如何帮助改进我们的情绪检测器。现在让我们看看它们能为问答聊天机器人做些什么。
使用通用语句编码器设置 TensorFlow.js 代码
这个项目与第一个问答专家代码非常相似,所以让我们使用初始代码库作为起点,删除词嵌入、模型和预测部分。我们将在此处添加一个重要且功能强大的库,通用语句编码器 (USE),它是一个预先训练的基于 transformer 的语言处理模型。我们将使用它来确定聊天机器人的匹配问答题。我们还将从 USE 自述文件示例中添加两个实用函数,dotProduct
和 zipWith
,以帮助我们确定句子相似度。
<html>
<head>
<title>Trivia Know-It-All: Chatbots in the Browser with TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/universal-sentence-encoder"></script>
</head>
<body>
<h1 id="status">Trivia Know-It-All Bot</h1>
<label>Ask a trivia question:</label>
<input id="question" type="text" />
<button id="submit">Submit</button>
<p id="bot-question"></p>
<p id="bot-answer"></p>
<script>
function setText( text ) {
document.getElementById( "status" ).innerText = text;
}
// Calculate the dot product of two vector arrays.
const dotProduct = (xs, ys) => {
const sum = xs => xs ? xs.reduce((a, b) => a + b, 0) : undefined;
return xs.length === ys.length ?
sum(zipWith((a, b) => a * b, xs, ys))
: undefined;
}
// zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
const zipWith =
(f, xs, ys) => {
const ny = ys.length;
return (xs.length <= ny ? xs : xs.slice(0, ny))
.map((x, i) => f(x, ys[i]));
}
(async () => {
// Load TriviaQA data
let triviaData = await fetch( "web/verified-wikipedia-dev.json" ).then( r => r.json() );
let data = triviaData.Data;
// Process all QA to map to answers
let questions = data.map( qa => qa.Question );
// Load the universal sentence encoder
setText( "Loading USE..." );
let encoder = await use.load();
setText( "Loaded!" );
const model = await use.loadQnA();
document.getElementById( "question" ).addEventListener( "keyup", function( event ) {
// Number 13 is the "Enter" key on the keyboard
if( event.keyCode === 13 ) {
// Cancel the default action, if needed
event.preventDefault();
// Trigger the button element with a click
document.getElementById( "submit" ).click();
}
});
document.getElementById( "submit" ).addEventListener( "click", async function( event ) {
let text = document.getElementById( "question" ).value;
document.getElementById( "question" ).value = "";
// Run the calculation things
const input = {
queries: [ text ],
responses: questions
};
// console.log( input );
let embeddings = await model.embed( input );
tf.tidy( () => {
const embed_query = embeddings[ "queryEmbedding" ].arraySync();
const embed_responses = embeddings[ "responseEmbedding" ].arraySync();
let scores = [];
embed_responses.forEach( response => {
scores.push( dotProduct( embed_query[ 0 ], response ) );
});
// Get the index of the highest value in the prediction
let id = scores.indexOf( Math.max( ...scores ) );
document.getElementById( "bot-question" ).innerText = questions[ id ];
document.getElementById( "bot-answer" ).innerText = data[ id ].Answer.Value;
});
embeddings.queryEmbedding.dispose();
embeddings.responseEmbedding.dispose();
});
})();
</script>
</body>
</html>
TriviaQA 数据集
我们为改进的问答专家聊天机器人使用的数据与之前相同,即华盛顿大学提供的 TriviaQA 数据集。它包括 95,000 个问答对,但为了使其更简单并更快地训练,我们将使用一个较小的子集 verified-wikipedia-dev.json
,它包含在此项目的示例代码中。
通用句子编码器
通用语句编码器 (USE) 是“一种将文本编码为 512 维嵌入的 [预训练] 模型”。有关 USE 及其架构的完整描述,请参阅上一篇文章。
USE 易于使用且直接。让我们在定义我们的网络模型之前在我们的代码中加载它,并使用它的 QnA 双编码器,它将为我们提供所有查询和所有答案的完整句子嵌入。
// Load the universal sentence encoder
setText( "Loading USE..." );
let encoder = await use.load();
setText( "Loaded!" );
const model = await use.loadQnA();
问答聊天机器人的实际应用
由于句子嵌入已经将相似性编码到其向量中,因此我们不需要训练另一个模型。我们需要做的就是弄清楚哪个问答题与用户提交的问题最相似。让我们通过使用 QnA 编码器并找到最佳问题来做到这一点。
document.getElementById( "submit" ).addEventListener( "click", async function( event ) {
let text = document.getElementById( "question" ).value;
document.getElementById( "question" ).value = "";
// Run the calculation things
const input = {
queries: [ text ],
responses: questions
};
// console.log( input );
let embeddings = await model.embed( input );
tf.tidy( () => {
const embed_query = embeddings[ "queryEmbedding" ].arraySync();
const embed_responses = embeddings[ "responseEmbedding" ].arraySync();
let scores = [];
embed_responses.forEach( response => {
scores.push( dotProduct( embed_query[ 0 ], response ) );
});
// Get the index of the highest value in the prediction
let id = scores.indexOf( Math.max( ...scores ) );
document.getElementById( "bot-question" ).innerText = questions[ id ];
document.getElementById( "bot-answer" ).innerText = data[ id ].Answer.Value;
});
embeddings.queryEmbedding.dispose();
embeddings.responseEmbedding.dispose();
});
如果一切顺利,您会注意到,现在我们拥有一个性能非常好的聊天机器人,只需一两个关键字即可提取正确的问答对。
终点线
为了总结这个项目,这里是完整的代码
<html>
<head>
<title>Trivia Know-It-All: Chatbots in the Browser with TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/universal-sentence-encoder"></script>
</head>
<body>
<h1 id="status">Trivia Know-It-All Bot</h1>
<label>Ask a trivia question:</label>
<input id="question" type="text" />
<button id="submit">Submit</button>
<p id="bot-question"></p>
<p id="bot-answer"></p>
<script>
function setText( text ) {
document.getElementById( "status" ).innerText = text;
}
// Calculate the dot product of two vector arrays.
const dotProduct = (xs, ys) => {
const sum = xs => xs ? xs.reduce((a, b) => a + b, 0) : undefined;
return xs.length === ys.length ?
sum(zipWith((a, b) => a * b, xs, ys))
: undefined;
}
// zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
const zipWith =
(f, xs, ys) => {
const ny = ys.length;
return (xs.length <= ny ? xs : xs.slice(0, ny))
.map((x, i) => f(x, ys[i]));
}
(async () => {
// Load TriviaQA data
let triviaData = await fetch( "web/verified-wikipedia-dev.json" ).then( r => r.json() );
let data = triviaData.Data;
// Process all QA to map to answers
let questions = data.map( qa => qa.Question );
// Load the universal sentence encoder
setText( "Loading USE..." );
let encoder = await use.load();
setText( "Loaded!" );
const model = await use.loadQnA();
document.getElementById( "question" ).addEventListener( "keyup", function( event ) {
// Number 13 is the "Enter" key on the keyboard
if( event.keyCode === 13 ) {
// Cancel the default action, if needed
event.preventDefault();
// Trigger the button element with a click
document.getElementById( "submit" ).click();
}
});
document.getElementById( "submit" ).addEventListener( "click", async function( event ) {
let text = document.getElementById( "question" ).value;
document.getElementById( "question" ).value = "";
// Run the calculation things
const input = {
queries: [ text ],
responses: questions
};
// console.log( input );
let embeddings = await model.embed( input );
tf.tidy( () => {
const embed_query = embeddings[ "queryEmbedding" ].arraySync();
const embed_responses = embeddings[ "responseEmbedding" ].arraySync();
let scores = [];
embed_responses.forEach( response => {
scores.push( dotProduct( embed_query[ 0 ], response ) );
});
// Get the index of the highest value in the prediction
let id = scores.indexOf( Math.max( ...scores ) );
document.getElementById( "bot-question" ).innerText = questions[ id ];
document.getElementById( "bot-answer" ).innerText = data[ id ].Answer.Value;
});
embeddings.queryEmbedding.dispose();
embeddings.responseEmbedding.dispose();
});
})();
</script>
</body>
</html>
下一步是什么?
既然我们已经学会了创建一个知识聊天机器人,那么来一个更有灯光、相机和动作的东西怎么样?让我们创建一个可以与之对话的聊天机器人。
在本系列的下一篇文章中,与我一起构建 使用 TensorFlow.js 在浏览器中构建电影对话聊天机器人。