Skip to content

Commit

Permalink
Always require credentials to download databases from github
Browse files Browse the repository at this point in the history
  • Loading branch information
robertbrignull committed Feb 26, 2025
1 parent 82e98c5 commit b855574
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 216 deletions.
10 changes: 1 addition & 9 deletions extensions/ql-vscode/src/databases/database-fetcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ import {
allowHttp,
downloadTimeout,
getGitHubInstanceUrl,
hasGhecDrUri,
isCanary,
} from "../config";
import { showAndLogInformationMessage } from "../common/logging";
import { AppOctokit } from "../common/octokit";
import type { DatabaseOrigin } from "./local-databases/database-origin";
import { createTimeoutSignal } from "../common/fetch-stream";
import type { App } from "../common/app";
Expand Down Expand Up @@ -187,12 +184,7 @@ export class DatabaseFetcher {
throw new Error(`Invalid GitHub repository: ${githubRepo}`);
}

const credentials =
isCanary() || hasGhecDrUri() ? this.app.credentials : undefined;

const octokit = credentials
? await credentials.getOctokit()
: new AppOctokit();
const octokit = await this.app.credentials.getOctokit();

const result = await convertGithubNwoToDatabaseUrl(
nwo,
Expand Down
56 changes: 11 additions & 45 deletions extensions/ql-vscode/src/databases/github-databases/api.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import { RequestError } from "@octokit/request-error";
import type { Octokit } from "@octokit/rest";
import type { RestEndpointMethodTypes } from "@octokit/plugin-rest-endpoint-methods";
import { showNeverAskAgainDialog } from "../../common/vscode/dialog";
import type { GitHubDatabaseConfig } from "../../config";
import { hasGhecDrUri } from "../../config";
import type { Credentials } from "../../common/authentication";
import { AppOctokit } from "../../common/octokit";
import type { ProgressCallback } from "../../common/vscode/progress";
import { getErrorMessage } from "../../common/helpers-pure";
import { getLanguageDisplayName } from "../../common/query-language";
Expand Down Expand Up @@ -68,52 +65,21 @@ export async function listDatabases(
credentials: Credentials,
config: GitHubDatabaseConfig,
): Promise<ListDatabasesResult | undefined> {
// On GHEC-DR, unauthenticated requests will never work, so we should always ask
// for authentication.
const hasAccessToken =
!!(await credentials.getExistingAccessToken()) || hasGhecDrUri();
const hasAccessToken = !!(await credentials.getExistingAccessToken());

let octokit = hasAccessToken
? await credentials.getOctokit()
: new AppOctokit();

let promptedForCredentials = false;

let databases: CodeqlDatabase[];
try {
const response = await octokit.rest.codeScanning.listCodeqlDatabases({
owner,
repo,
});
databases = response.data;
} catch (e) {
// If we get a 404 when we don't have an access token, it might be because
// the repository is private/internal. Therefore, we should ask the user
// whether they want to connect to GitHub and try again.
if (e instanceof RequestError && e.status === 404 && !hasAccessToken) {
// Check whether the user wants to connect to GitHub
if (!(await askForGitHubConnect(config))) {
return;
}

// Prompt for credentials
octokit = await credentials.getOctokit();

promptedForCredentials = true;

const response = await octokit.rest.codeScanning.listCodeqlDatabases({
owner,
repo,
});
databases = response.data;
} else {
throw e;
}
if (!hasAccessToken && !(await askForGitHubConnect(config))) {
return undefined;
}
const octokit = await credentials.getOctokit();

const response = await octokit.rest.codeScanning.listCodeqlDatabases({
owner,
repo,
});

return {
promptedForCredentials,
databases,
promptedForCredentials: !hasAccessToken,
databases: response.data,
octokit,
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,9 @@ import {
} from "../../../../../src/databases/github-databases/api";
import type { Credentials } from "../../../../../src/common/authentication";
import type { Octokit } from "@octokit/rest";
import { AppOctokit } from "../../../../../src/common/octokit";
import { RequestError } from "@octokit/request-error";
import { window } from "vscode";

// Mock the AppOctokit constructor to ensure we aren't making any network requests
jest.mock("../../../../../src/common/octokit", () => ({
AppOctokit: jest.fn(),
}));
const appMockListCodeqlDatabases = mockedOctokitFunction<
"codeScanning",
"listCodeqlDatabases"
>();
const appOctokit = mockedObject<Octokit>({
rest: {
codeScanning: {
listCodeqlDatabases: appMockListCodeqlDatabases,
},
},
});
beforeEach(() => {
(AppOctokit as unknown as jest.Mock).mockImplementation(() => appOctokit);
});

describe("listDatabases", () => {
const owner = "github";
const repo = "codeql";
Expand Down Expand Up @@ -161,29 +141,59 @@ describe("listDatabases", () => {
});

describe("when the user does not have an access token", () => {
describe("when the repo is public", () => {
beforeEach(() => {
credentials = mockedObject<Credentials>({
getExistingAccessToken: () => undefined,
});
beforeEach(() => {
credentials = mockedObject<Credentials>({
getExistingAccessToken: () => undefined,
getOctokit: () => octokit,
});
});

mockListCodeqlDatabases.mockResolvedValue(undefined);
appMockListCodeqlDatabases.mockResolvedValue(successfulMockApiResponse);
describe("when answering connect to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Connect");
});

it("returns the databases", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual({
databases,
promptedForCredentials: false,
octokit: appOctokit,
promptedForCredentials: true,
octokit,
});
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(mockListCodeqlDatabases).toHaveBeenCalled();
});

describe("when the request fails with a 404", () => {
beforeEach(() => {
mockListCodeqlDatabases.mockRejectedValue(
new RequestError("Not found", 404, {
request: {
method: "GET",
url: "",
headers: {},
},
response: {
status: 404,
headers: {},
url: "",
data: {},
retryCount: 0,
},
}),
);
});

it("throws an error", async () => {
await expect(
listDatabases(owner, repo, credentials, config),
).rejects.toThrow("Not found");
});
expect(showNeverAskAgainDialogSpy).not.toHaveBeenCalled();
});

describe("when the request fails with a 500", () => {
beforeEach(() => {
appMockListCodeqlDatabases.mockRejectedValue(
mockListCodeqlDatabases.mockRejectedValue(
new RequestError("Internal server error", 500, {
request: {
method: "GET",
Expand All @@ -205,151 +215,49 @@ describe("listDatabases", () => {
await expect(
listDatabases(owner, repo, credentials, config),
).rejects.toThrow("Internal server error");
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
});
});
});

describe("when the repo is private", () => {
describe("when cancelling prompt", () => {
beforeEach(() => {
credentials = mockedObject<Credentials>({
getExistingAccessToken: () => undefined,
getOctokit: () => octokit,
});

appMockListCodeqlDatabases.mockRejectedValue(
new RequestError("Not found", 404, {
request: {
method: "GET",
url: "",
headers: {},
},
response: {
status: 404,
headers: {},
url: "",
data: {},
retryCount: 0,
},
}),
);
showNeverAskAgainDialogSpy.mockResolvedValue(undefined);
});

describe("when answering connect to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Connect");
});

it("returns the databases", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual({
databases,
promptedForCredentials: true,
octokit,
});
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(appMockListCodeqlDatabases).toHaveBeenCalled();
expect(mockListCodeqlDatabases).toHaveBeenCalled();
});

describe("when the request fails with a 404", () => {
beforeEach(() => {
mockListCodeqlDatabases.mockRejectedValue(
new RequestError("Not found", 404, {
request: {
method: "GET",
url: "",
headers: {},
},
response: {
status: 404,
headers: {},
url: "",
data: {},
retryCount: 0,
},
}),
);
});

it("throws an error", async () => {
await expect(
listDatabases(owner, repo, credentials, config),
).rejects.toThrow("Not found");
});
});

describe("when the request fails with a 500", () => {
beforeEach(() => {
mockListCodeqlDatabases.mockRejectedValue(
new RequestError("Internal server error", 500, {
request: {
method: "GET",
url: "",
headers: {},
},
response: {
status: 500,
headers: {},
url: "",
data: {},
retryCount: 0,
},
}),
);
});

it("throws an error", async () => {
await expect(
listDatabases(owner, repo, credentials, config),
).rejects.toThrow("Internal server error");
});
});
it("returns undefined", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).not.toHaveBeenCalled();
});
});

describe("when cancelling prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue(undefined);
});

it("returns undefined", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(appMockListCodeqlDatabases).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).not.toHaveBeenCalled();
});
describe("when answering not now to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Not now");
});

describe("when answering not now to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Not now");
});

it("returns undefined", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(appMockListCodeqlDatabases).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).not.toHaveBeenCalled();
});
it("returns undefined", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).not.toHaveBeenCalled();
});
});

describe("when answering never to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Never");
});
describe("when answering never to prompt", () => {
beforeEach(() => {
showNeverAskAgainDialogSpy.mockResolvedValue("Never");
});

it("returns undefined and sets the config to 'never'", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(appMockListCodeqlDatabases).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).toHaveBeenCalledWith("never");
});
it("returns undefined and sets the config to 'never'", async () => {
const result = await listDatabases(owner, repo, credentials, config);
expect(result).toEqual(undefined);
expect(showNeverAskAgainDialogSpy).toHaveBeenCalled();
expect(mockListCodeqlDatabases).not.toHaveBeenCalled();
expect(setDownload).toHaveBeenCalledWith("never");
});
});
});
Expand Down

0 comments on commit b855574

Please sign in to comment.