From 944566d9b8d9b14ad1f8039312a8c890682b03c5 Mon Sep 17 00:00:00 2001 From: copoer Date: Sun, 1 Jan 2023 14:11:07 -0400 Subject: [PATCH 1/2] Added small and large dialog gpt models --- src/gpt2/gpt2_model.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index cbf442dcf..f964052c5 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -70,11 +70,21 @@ impl Gpt2ModelResources { "distilgpt2/model", "https://huggingface.co/distilgpt2/resolve/main/rust_model.ot", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-small>. Modified with conversion to C-array format. + pub const DIALOGPT_SMALL: (&'static str, &'static str) = ( + "dialogpt-small/model", + "https://huggingface.co/microsoft/DialoGPT-small/resolve/main/rust_model.ot", + ); /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/model", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-large>. Modified with conversion to C-array format. 
+ pub const DIALOGPT_LARGE: (&'static str, &'static str) = ( + "dialogpt-large/model", + "https://huggingface.co/microsoft/DialoGPT-large/resolve/main/rust_model.ot", + ); } impl Gpt2ConfigResources { From acf10708d7aca68ca68be7652a6eb6d6e5c2f881 Mon Sep 17 00:00:00 2001 From: copoer Date: Mon, 2 Jan 2023 13:09:26 -0400 Subject: [PATCH 2/2] Test (#1) * Testing model * Added model config, vocab, merges * Revert testing --- src/gpt2/gpt2_model.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index f964052c5..4a7c6c676 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -113,11 +113,23 @@ impl Gpt2ConfigResources { "distilgpt2/config", "https://huggingface.co/distilgpt2/resolve/main/config.json", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-small>. Modified with conversion to C-array format. + pub const DIALOGPT_SMALL: (&'static str, &'static str) = ( + "dialogpt-small/config", + "https://huggingface.co/microsoft/DialoGPT-small/resolve/main/config.json", + ); /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/config", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-large>. Modified with conversion to C-array format. + pub const DIALOGPT_LARGE: (&'static str, &'static str) = ( + "dialogpt-large/config", + "https://huggingface.co/microsoft/DialoGPT-large/resolve/main/config.json", + ); + + } impl Gpt2VocabResources { @@ -146,11 +158,21 @@ impl Gpt2VocabResources { "distilgpt2/vocab", "https://huggingface.co/distilgpt2/resolve/main/vocab.json", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-small>. Modified with conversion to C-array format. 
+ pub const DIALOGPT_SMALL: (&'static str, &'static str) = ( + "dialogpt-small/vocab", + "https://huggingface.co/microsoft/DialoGPT-small/resolve/main/vocab.json", + ); /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/vocab", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-large>. Modified with conversion to C-array format. + pub const DIALOGPT_LARGE: (&'static str, &'static str) = ( + "dialogpt-large/vocab", + "https://huggingface.co/microsoft/DialoGPT-large/resolve/main/vocab.json", + ); } impl Gpt2MergesResources { @@ -179,11 +201,22 @@ impl Gpt2MergesResources { "distilgpt2/merges", "https://huggingface.co/distilgpt2/resolve/main/merges.txt", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-small>. Modified with conversion to C-array format. + pub const DIALOGPT_SMALL: (&'static str, &'static str) = ( + "dialogpt-small/merges", + "https://huggingface.co/microsoft/DialoGPT-small/resolve/main/merges.txt", + ); /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format. pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = ( "dialogpt-medium/merges", "https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt", ); + /// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-large>. Modified with conversion to C-array format. + pub const DIALOGPT_LARGE: (&'static str, &'static str) = ( + "dialogpt-large/merges", + "https://huggingface.co/microsoft/DialoGPT-large/resolve/main/merges.txt", + ); + } #[derive(Debug, Serialize, Deserialize, Clone)]