Skip to content

Commit dcb72f1

Browse files
committed
add countToken support
add model support add models support add api version control support
1 parent d179f77 commit dcb72f1

File tree

8 files changed

+142
-25
lines changed

8 files changed

+142
-25
lines changed

README.md

+23-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
gemini.api-key={your gemini api key}
2020
gemini.proxy-host={your http proxy host}
2121
gemini.proxy-port={your http proxy port}
22+
gemini.version={gemini api version, default is v1beta}
2223
```
2324

2425
* Use GeminiClient in your code
@@ -28,6 +29,8 @@ package com.codingapi.gemini.client;
2829

2930
import com.codingapi.gemini.pojo.Embedding;
3031
import com.codingapi.gemini.pojo.Generate;
32+
import com.codingapi.gemini.pojo.Model;
33+
import com.codingapi.gemini.pojo.Models;
3134
import lombok.SneakyThrows;
3235
import org.junit.jupiter.api.Test;
3336
import org.springframework.beans.factory.annotation.Autowired;
@@ -51,7 +54,6 @@ class GeminiClientTest {
5154
System.out.println(answer);
5255
}
5356

54-
5557
@Test
5658
void generateConfiguration() {
5759
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
@@ -80,13 +82,33 @@ class GeminiClientTest {
8082
System.out.println(answer);
8183
}
8284

85+
@Test
86+
void counts() {
87+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
88+
int tokens = client.counts(request);
89+
System.out.println(tokens);
90+
assert tokens > 0;
91+
}
92+
8393
@Test
8494
void embedding() {
8595
Embedding.Request request = Embedding.creat("你好,我是小强");
8696
Embedding.Response response = client.embedding(request);
8797
List<Double> answer = Embedding.toAnswer(response);
8898
System.out.println(answer);
8999
}
100+
101+
@Test
102+
void model() {
103+
Model model = client.model("models/gemini-pro");
104+
System.out.println(model);
105+
}
106+
107+
@Test
108+
void models() {
109+
Models models = client.models();
110+
System.out.println(models);
111+
}
90112
}
91113
```
92114

src/main/java/com/codingapi/gemini/GeminiConfiguration.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@ public class GeminiConfiguration {
1111

1212
@Bean
1313
@ConfigurationProperties(prefix = "gemini")
14-
public GeminiProperties geminiProperties(){
14+
public GeminiProperties geminiProperties() {
1515
return new GeminiProperties();
1616
}
1717

1818
@Bean
19-
public GeminiClient geminiClient(GeminiProperties geminiProperties){
20-
return new GeminiClient(geminiProperties.getApiKey(), geminiProperties.getProxyHost(), geminiProperties.getProxyPort());
19+
public GeminiClient geminiClient(GeminiProperties geminiProperties) {
20+
return new GeminiClient(
21+
geminiProperties.getVersion(),
22+
geminiProperties.getApiKey(),
23+
geminiProperties.getProxyHost(),
24+
geminiProperties.getProxyPort());
2125
}
2226

2327

src/main/java/com/codingapi/gemini/client/GeminiClient.java

+38-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import com.alibaba.fastjson.JSONObject;
55
import com.codingapi.gemini.pojo.Embedding;
66
import com.codingapi.gemini.pojo.Generate;
7+
import com.codingapi.gemini.pojo.Model;
8+
import com.codingapi.gemini.pojo.Models;
79
import lombok.extern.slf4j.Slf4j;
810
import org.springframework.core.io.Resource;
911
import org.springframework.http.HttpEntity;
@@ -27,13 +29,13 @@ public class GeminiClient {
2729

2830
private final RestTemplate restTemplate;
2931
private final String apiKey;
30-
31-
private final static String baseUrl = "https://generativelanguage.googleapis.com/v1beta/";
32-
32+
private final String baseUrl;
3333
private final HttpHeaders headers;
3434

35-
public GeminiClient(String apiKey, String proxyHost, int proxyPort) {
35+
36+
public GeminiClient(String version, String apiKey, String proxyHost, int proxyPort) {
3637
this.apiKey = apiKey;
38+
this.baseUrl = "https://generativelanguage.googleapis.com/" + version + "/";
3739
this.restTemplate = new RestTemplate();
3840

3941
this.headers = new HttpHeaders();
@@ -48,7 +50,7 @@ public GeminiClient(String apiKey, String proxyHost, int proxyPort) {
4850
}
4951

5052
public void stream(Generate.Request request, Consumer<Generate.Response> consumer) throws IOException {
51-
String url = baseUrl + "models/gemini-pro:streamGenerateContent?key=" + apiKey;
53+
String url = baseUrl + request.getModel() + ":streamGenerateContent?key=" + apiKey;
5254
String json = request.toJSONString();
5355
log.info("json:{}", json);
5456
HttpEntity<String> httpEntity = new HttpEntity<>(json, headers);
@@ -68,26 +70,50 @@ public void stream(Generate.Request request, Consumer<Generate.Response> consume
6870

6971

7072
public Generate.Response generate(Generate.Request request) {
71-
String url;
72-
if (request.isVision()) {
73-
url = baseUrl + "models/gemini-pro-vision:generateContent?key=" + apiKey;
74-
} else {
75-
url = baseUrl + "models/gemini-pro:generateContent?key=" + apiKey;
76-
}
73+
String url = baseUrl + request.getModel() + ":generateContent?key=" + apiKey;
7774
String json = request.toJSONString();
7875
log.info("json:{}", json);
7976
HttpEntity<String> httpEntity = new HttpEntity<>(json, headers);
8077
ResponseEntity<String> response = restTemplate.exchange(url, HttpMethod.POST, httpEntity, String.class);
8178
return JSONObject.parseObject(response.getBody(), Generate.Response.class);
8279
}
8380

81+
82+
public int counts(Generate.Request request) {
83+
String url = baseUrl + request.getModel() + ":countTokens?key=" + apiKey;
84+
String json = request.toJSONString();
85+
log.info("json:{}", json);
86+
HttpEntity<String> httpEntity = new HttpEntity<>(json, headers);
87+
ResponseEntity<String> response = restTemplate.exchange(url, HttpMethod.POST, httpEntity, String.class);
88+
Generate.TotalToken result = JSONObject.parseObject(response.getBody(), Generate.TotalToken.class);
89+
assert result != null;
90+
return result.getTotalTokens();
91+
}
92+
93+
8494
public Embedding.Response embedding(Embedding.Request request) {
85-
String url = baseUrl + "models//embedding-001:embedContent?key=" + apiKey;
95+
String url = baseUrl + request.getModel() + ":embedContent?key=" + apiKey;
8696
String json = request.toJSONString();
8797
log.info("json:{}", json);
8898
HttpEntity<String> httpEntity = new HttpEntity<>(json, headers);
8999
ResponseEntity<String> response = restTemplate.exchange(url, HttpMethod.POST, httpEntity, String.class);
90100
return JSONObject.parseObject(response.getBody(), Embedding.Response.class);
91101
}
92102

103+
104+
public Model model(String model) {
105+
String url = baseUrl + model + "?key=" + apiKey;
106+
HttpEntity<String> httpEntity = new HttpEntity<>(headers);
107+
ResponseEntity<String> response = restTemplate.exchange(url, HttpMethod.GET, httpEntity, String.class);
108+
return JSONObject.parseObject(response.getBody(), Model.class);
109+
}
110+
111+
public Models models() {
112+
String url = baseUrl + "models" + "?key=" + apiKey;
113+
HttpEntity<String> httpEntity = new HttpEntity<>(headers);
114+
ResponseEntity<String> response = restTemplate.exchange(url, HttpMethod.GET, httpEntity, String.class);
115+
return JSONObject.parseObject(response.getBody(), Models.class);
116+
}
117+
118+
93119
}

src/main/java/com/codingapi/gemini/pojo/Generate.java

+11-8
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,20 @@
1717
public class Generate {
1818

1919
public static Request creatTextChart(String text) {
20-
Request request = new Request();
20+
Request request = new Request("models/gemini-pro");
2121
Request.Chat chat = new Request.Chat();
2222
chat.setRole("user");
2323
chat.getParts().add(new Request.TextPart(text));
2424
request.getContents().add(chat);
25-
request.vision = false;
2625
return request;
2726
}
2827

2928
public static Request creatImageChart(String text, File image) throws IOException {
30-
Request request = new Request();
29+
Request request = new Request("models/gemini-pro-vision");
3130
Request.Chat chat = new Request.Chat();
3231
chat.getParts().add(new Request.TextPart(text));
3332
chat.getParts().add(new Request.ImagePart(image));
3433
request.getContents().add(chat);
35-
request.vision = true;
3634
return request;
3735
}
3836

@@ -83,12 +81,13 @@ public static class Request {
8381
private List<Chat> contents;
8482

8583
@JSONField(serialize = false)
86-
private boolean vision;
84+
private String model;
8785

8886
private List<SafetySetting> safetySettings;
8987
private GenerationConfig generationConfig;
9088

91-
public Request() {
89+
public Request(String model) {
90+
this.model = model;
9291
this.contents = new ArrayList<>();
9392
}
9493

@@ -159,11 +158,15 @@ public static class Chat {
159158
public Chat() {
160159
this.parts = new ArrayList<>();
161160
}
162-
163-
164161
}
165162
}
166163

164+
@Setter
165+
@Getter
166+
public static class TotalToken{
167+
private int totalTokens;
168+
}
169+
167170
@Setter
168171
@Getter
169172
public static class Response {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.codingapi.gemini.pojo;
2+
3+
import lombok.Getter;
4+
import lombok.Setter;
5+
import lombok.ToString;
6+
7+
import java.util.List;
8+
9+
@Setter
10+
@Getter
11+
@ToString
12+
public class Model {
13+
14+
private String name;
15+
private String version;
16+
private String displayName;
17+
private String description;
18+
private int inputTokenLimit;
19+
private int outputTokenLimit;
20+
private List<String> supportedGenerationMethods;
21+
private float temperature;
22+
private float topP;
23+
private float topK;
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.codingapi.gemini.pojo;
2+
3+
import lombok.Getter;
4+
import lombok.Setter;
5+
import lombok.ToString;
6+
7+
import java.util.List;
8+
9+
@Setter
10+
@Getter
11+
@ToString
12+
public class Models {
13+
14+
private List<Model> models;
15+
}

src/main/java/com/codingapi/gemini/properties/GeminiProperties.java

+2
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ public class GeminiProperties {
1010
private String apiKey;
1111
private String proxyHost;
1212
private int proxyPort;
13+
14+
private String version = "v1beta";
1315
}

src/test/java/com/codingapi/gemini/client/GeminiClientTest.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import com.codingapi.gemini.pojo.Embedding;
44
import com.codingapi.gemini.pojo.Generate;
5+
import com.codingapi.gemini.pojo.Model;
6+
import com.codingapi.gemini.pojo.Models;
57
import lombok.SneakyThrows;
68
import org.junit.jupiter.api.Test;
79
import org.springframework.beans.factory.annotation.Autowired;
@@ -25,7 +27,6 @@ void generate() {
2527
System.out.println(answer);
2628
}
2729

28-
2930
@Test
3031
void generateConfiguration() {
3132
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
@@ -54,11 +55,31 @@ void generateVision() throws IOException {
5455
System.out.println(answer);
5556
}
5657

58+
@Test
59+
void counts() {
60+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
61+
int tokens = client.counts(request);
62+
System.out.println(tokens);
63+
assert tokens > 0;
64+
}
65+
5766
@Test
5867
void embedding() {
5968
Embedding.Request request = Embedding.creat("你好,我是小强");
6069
Embedding.Response response = client.embedding(request);
6170
List<Double> answer = Embedding.toAnswer(response);
6271
System.out.println(answer);
6372
}
73+
74+
@Test
75+
void model() {
76+
Model model = client.model("models/gemini-pro");
77+
System.out.println(model);
78+
}
79+
80+
@Test
81+
void models() {
82+
Models models = client.models();
83+
System.out.println(models);
84+
}
6485
}

0 commit comments

Comments
 (0)